mirror of
https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00
update main.py
This commit is contained in:
parent
be01a30ac2
commit
39232d1e02
45
main.py
45
main.py
@ -30,19 +30,18 @@ import models_v2
|
||||
import model_sparse
|
||||
import random
|
||||
import utils
|
||||
import wandb
|
||||
|
||||
from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam, check_valid_pruner
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
|
||||
parser.add_argument('--batch-size', default=64, type=int)
|
||||
parser.add_argument('--batch-size', default=128, type=int)
|
||||
parser.add_argument('--epochs', default=300, type=int)
|
||||
parser.add_argument('--bce-loss', action='store_true')
|
||||
parser.add_argument('--unscale-lr', action='store_true')
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--input-size', default=224, type=int, help='images input size')
|
||||
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
|
||||
@ -155,7 +154,7 @@ def get_args_parser():
|
||||
parser.add_argument('--attn-only', action='store_true')
|
||||
|
||||
# Dataset parameters
|
||||
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
|
||||
parser.add_argument('--data-path', default='/dataset/imagenet', type=str,
|
||||
help='dataset path')
|
||||
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
|
||||
type=str, help='Image Net dataset path')
|
||||
@ -163,8 +162,6 @@ def get_args_parser():
|
||||
choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
|
||||
type=str, help='semantic granularity')
|
||||
|
||||
parser.add_argument('--output_dir', default='',
|
||||
help='path where to save, empty for no saving')
|
||||
parser.add_argument('--device', default='cuda',
|
||||
help='device to use for training / testing')
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
@ -187,8 +184,14 @@ def get_args_parser():
|
||||
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
|
||||
|
||||
# Sparsity Training Related Flag
|
||||
parser.add_argument('--nas-config', type=str, help='configuration for supernet training')
|
||||
parser.add_argument('--nas-mode', action='store_true')
|
||||
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_ea124_9.0M.yml', help='configuration for supernet training')
|
||||
parser.add_argument('--nas-mode', action='store_true', default=True)
|
||||
parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
|
||||
parser.add_argument('--wandb', action='store_true')
|
||||
parser.add_argument('--output_dir', default='result',
|
||||
help='path where to save, empty for no saving')
|
||||
return parser
|
||||
|
||||
|
||||
@ -289,6 +292,14 @@ def main(args):
|
||||
img_size=args.input_size
|
||||
)
|
||||
|
||||
if args.wandb:
|
||||
wandb.init(project='sparsity')
|
||||
|
||||
# load nas pretrained weight
|
||||
if args.nas_weights:
|
||||
state_dict = torch.load(args.nas_weights)
|
||||
model.load_state_dict(state_dict['model'], strict=True)
|
||||
|
||||
if args.finetune:
|
||||
if args.finetune.startswith('https'):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
@ -358,11 +369,23 @@ def main(args):
|
||||
device='cpu' if args.model_ema_force_cpu else '',
|
||||
resume='')
|
||||
|
||||
|
||||
## Add ASP to model (apex)
|
||||
one_ll = model.blocks[0].attn.proj.weight
|
||||
ASP.init_model_for_pruning(model, "m4n2_1d", whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=True)
|
||||
# ASP.init_optimizer_for_pruning(optimizer)
|
||||
print("DENSE :: ", one_ll)
|
||||
ASP.compute_sparse_masks()
|
||||
print("SPARSE :: ", one_ll)
|
||||
return
|
||||
|
||||
model_without_ddp = model
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||
model_without_ddp = model.module
|
||||
|
||||
print(model_without_ddp)
|
||||
|
||||
if args.nas_mode:
|
||||
smallest_config = []
|
||||
for ratios in nas_config['sparsity']['choices']:
|
||||
@ -471,7 +494,13 @@ def main(args):
|
||||
|
||||
|
||||
test_stats = evaluate(data_loader_val, model, device)
|
||||
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
if wandb and wandb.run:
|
||||
wandb.log({**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
**{f'test_{k}': v for k, v in test_stats.items()},
|
||||
'epoch': epoch,
|
||||
'n_parameters': n_parameters})
|
||||
|
||||
if max_accuracy < test_stats["acc1"]:
|
||||
max_accuracy = test_stats["acc1"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user