update main.py

Author: max410011
Date:   2023-05-08 08:33:14 +00:00
Commit: 39232d1e02 (parent be01a30ac2)

main.py (45 changed lines)

@@ -30,19 +30,18 @@ import models_v2
 import model_sparse
 import random
 import utils
+import wandb
+from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam, check_valid_pruner
 
 def get_args_parser():
     parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
-    parser.add_argument('--batch-size', default=64, type=int)
+    parser.add_argument('--batch-size', default=128, type=int)
     parser.add_argument('--epochs', default=300, type=int)
     parser.add_argument('--bce-loss', action='store_true')
     parser.add_argument('--unscale-lr', action='store_true')
 
     # Model parameters
     parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                         help='Name of model to train')
     parser.add_argument('--input-size', default=224, type=int, help='images input size')
     parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
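
The new imports pull in Weights & Biases logging and the repo's N:M pruning helpers (weight_pruner_loader, prune_weights_reparam, check_valid_pruner), whose internals are not shown in this diff. As a rough illustration of the N:M pattern this training targets (2:4 in the "m4n2" naming used later in the commit), here is a minimal, hypothetical sketch; check_nm_sparsity is not a repo function:

import torch

def check_nm_sparsity(weight: torch.Tensor, n: int = 2, m: int = 4) -> bool:
    # An n:m-sparse tensor keeps at most n nonzeros in every group of m
    # consecutive elements along the last dimension.
    groups = weight.reshape(-1, m)
    return bool((groups.ne(0).sum(dim=1) <= n).all())

w = torch.randn(8, 16)
mask = torch.zeros_like(w)
# Keep the 2 largest-magnitude entries in each group of 4 (the 2:4 pattern).
idx = w.abs().reshape(-1, 4).topk(2, dim=1).indices
mask.reshape(-1, 4).scatter_(1, idx, 1.0)
print(check_nm_sparsity(w * mask))  # True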
@@ -155,7 +154,7 @@ def get_args_parser():
     parser.add_argument('--attn-only', action='store_true')
 
     # Dataset parameters
-    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
+    parser.add_argument('--data-path', default='/dataset/imagenet', type=str,
                         help='dataset path')
     parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                         type=str, help='Image Net dataset path')
@@ -163,8 +162,6 @@ def get_args_parser():
                         choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                         type=str, help='semantic granularity')
-    parser.add_argument('--output_dir', default='',
-                        help='path where to save, empty for no saving')
     parser.add_argument('--device', default='cuda',
                         help='device to use for training / testing')
     parser.add_argument('--seed', default=0, type=int)
@@ -187,8 +184,14 @@ def get_args_parser():
     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
 
     # Sparsity Training Related Flag
-    parser.add_argument('--nas-config', type=str, help='configuration for supernet training')
-    parser.add_argument('--nas-mode', action='store_true')
+    parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
+                        help='Name of model to train')
+    parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_ea124_9.0M.yml', help='configuration for supernet training')
+    parser.add_argument('--nas-mode', action='store_true', default=True)
+    parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
+    parser.add_argument('--wandb', action='store_true')
+    parser.add_argument('--output_dir', default='result',
+                        help='path where to save, empty for no saving')
 
     return parser
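
With these defaults the script now starts in supernet/NAS mode out of the box. Two things stand out: action='store_true' combined with default=True means --nas-mode can never be switched off from the command line, and the second --model definition will make argparse raise on a duplicate option string unless the earlier --model (still visible as context in the first hunk) is removed elsewhere. A standalone sketch of just the new flags (not the full DeiT parser):

import argparse

parser = argparse.ArgumentParser('NAS flags sketch', add_help=False)
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_ea124_9.0M.yml')
parser.add_argument('--nas-mode', action='store_true', default=True)
parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth')
parser.add_argument('--wandb', action='store_true')

args = parser.parse_args([])
print(args.nas_mode)    # True even with no CLI flags: NAS mode is always on
print(args.nas_config)  # configs/deit_small_nxm_ea124_9.0M.yml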
@@ -289,6 +292,14 @@ def main(args):
         img_size=args.input_size
     )
 
+    if args.wandb:
+        wandb.init(project='sparsity')
+
+    # load nas pretrained weight
+    if args.nas_weights:
+        state_dict = torch.load(args.nas_weights)
+        model.load_state_dict(state_dict['model'], strict=True)
+
     if args.finetune:
         if args.finetune.startswith('https'):
             checkpoint = torch.hub.load_state_dict_from_url(
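
Because --nas-weights now defaults to a real path, this load runs on every invocation unless the flag is overridden with an empty string, and it happens before the --finetune branch. A minimal sketch of the same step as a helper (load_nas_weights is hypothetical; the diff calls torch.load without map_location, which restores tensors onto the devices they were saved from):

import torch

def load_nas_weights(model: torch.nn.Module, path: str) -> None:
    # map_location='cpu' keeps loading device-agnostic before .to(device)
    ckpt = torch.load(path, map_location='cpu')
    # The checkpoint is assumed to store {'model': state_dict}; strict=True
    # demands an exact key match with the supernet.
    model.load_state_dict(ckpt['model'], strict=True)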
@@ -358,11 +369,23 @@ def main(args):
             device='cpu' if args.model_ema_force_cpu else '',
             resume='')
 
+    ## Add ASP to model (apex)
+    one_ll = model.blocks[0].attn.proj.weight
+    ASP.init_model_for_pruning(model, "m4n2_1d", whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=True)
+    # ASP.init_optimizer_for_pruning(optimizer)
+    print("DENSE :: ", one_ll)
+    ASP.compute_sparse_masks()
+    print("SPARSE :: ", one_ll)
+    return
+
     model_without_ddp = model
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
+    print(model_without_ddp)
 
     if args.nas_mode:
         smallest_config = []
         for ratios in nas_config['sparsity']['choices']:
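
The ASP block wires NVIDIA Apex's Automatic SParsity into the model: the "m4n2_1d" mask calculator prunes each group of 4 weights down to 2 nonzeros, and the two prints compare the same attention projection weight before and after masking. Note the bare return right after compute_sparse_masks(): in this revision main() exits before DDP setup and training, so the block reads as a mask-inspection experiment. A toy usage sketch of the same ASP calls (whether apex.contrib is importable depends on how Apex was built):

import torch
from apex.contrib.sparsity import ASP

model = torch.nn.Sequential(torch.nn.Linear(64, 64))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ASP.init_model_for_pruning(model, "m4n2_1d",
                           whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                           allow_recompute_mask=True)
ASP.init_optimizer_for_pruning(optimizer)  # masks grads so pruned weights stay zero
ASP.compute_sparse_masks()                 # from here on, forward passes see 2:4 weights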
@@ -471,7 +494,13 @@ def main(args):
         test_stats = evaluate(data_loader_val, model, device)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
 
+        if wandb and wandb.run:
+            wandb.log({**{f'train_{k}': v for k, v in train_stats.items()},
+                       **{f'test_{k}': v for k, v in test_stats.items()},
+                       'epoch': epoch,
+                       'n_parameters': n_parameters})
+
         if max_accuracy < test_stats["acc1"]:
             max_accuracy = test_stats["acc1"]
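
The guard checks wandb.run rather than args.wandb, so logging is skipped cleanly whenever wandb.init() was never called. A self-contained sketch of the per-epoch logging with stand-in metric values (mode='offline' just avoids needing an account):

import wandb

wandb.init(project='sparsity', mode='offline')

train_stats = {'loss': 3.2, 'lr': 5e-4}    # stand-in values
test_stats = {'acc1': 71.8, 'acc5': 90.4}  # stand-in values
epoch, n_parameters = 0, 9_000_000         # stand-in values

if wandb and wandb.run:
    wandb.log({**{f'train_{k}': v for k, v in train_stats.items()},
               **{f'test_{k}': v for k, v in test_stats.items()},
               'epoch': epoch,
               'n_parameters': n_parameters})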