diff --git a/engine.py b/engine.py index c68a6ad..bfd27c8 100644 --- a/engine.py +++ b/engine.py @@ -69,19 +69,24 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, @torch.no_grad() -def evaluate(nas_config, data_loader, model, device, args = None): +def evaluate(nas_config, nas_test_config, data_loader, model, device, args = None): criterion = torch.nn.CrossEntropyLoss() metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' if args.nas_mode: - # Sample the smallest subnetwork to test accuracy - smallest_config = [] + # Sample the subnet to test accuracy + test_config = [] for ratios in nas_config['sparsity']['choices']: - smallest_config.append(ratios[0]) - # smallest_config.append([1, 3]) - model.module.set_sample_config(smallest_config) + if nas_test_config in ratios: + test_config.append(nas_test_config) + else: + # requested config not available: fall back to the smallest config for this layer + test_config.append(ratios[0]) + print(f'Test config {nas_test_config} is not in the choices, choose smallest config {ratios[0]}') + + model.module.set_sample_config(test_config) # switch to evaluation mode model.eval() diff --git a/main.py b/main.py index 7114694..3e7aace 100644 --- a/main.py +++ b/main.py @@ -149,7 +149,7 @@ def get_args_parser(): # parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'soft_fd'], type=str, help="") # parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") - parser.add_argument('--distillation-alpha', default=0.0, type=float, help="") + parser.add_argument('--distillation-alpha', default=1.0, type=float, help="") parser.add_argument('--distillation-tau', default=1.0, type=float, help="") parser.add_argument('--distillation-gamma', default=0.1, type=float, help="coefficient for hidden distillation loss, we set it to be 0.1 by aligning MiniViT") @@ -159,7 +159,7 @@ def get_args_parser(): 
parser.add_argument('--attn-only', action='store_true') # Dataset parameters - parser.add_argument('--data-path', default='/dev/shm/imagenet', type=str, help='dataset path') + parser.add_argument('--data-path', default='/dataset/imagenet', type=str, help='dataset path') parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], type=str, help='Image Net dataset path') parser.add_argument('--inat-category', default='name', @@ -187,17 +187,12 @@ def get_args_parser(): # Sparsity Training Related Flag # timm == 0.4.12 - # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --output_dir result_nas_1:4_150epoch_repeat - # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir result_nas_124+13_150epoch - # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch - # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch - # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval - # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--pretrained', action='store_true', help='Use pretrained model') parser.add_argument('--nas-config', type=str, default=None, help='configuration for supernet training') parser.add_argument('--nas-mode', action='store_true', default=False) + 
parser.add_argument('--nas-test-config', type=int, nargs='+', default=None, help='Use test config to eval accuracy') # parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight') # parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight') # parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight') @@ -483,8 +478,10 @@ def main(args): if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) lr_scheduler.step(args.start_epoch) + + # Evaluation-only mode: run a single eval pass on the requested subnet config and exit if args.eval: - test_stats = evaluate(nas_config, data_loader_val, model, device, args) + test_stats = evaluate(nas_config, args.nas_test_config, data_loader_val, model, device, args) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") return @@ -517,7 +514,7 @@ def main(args): }, checkpoint_path) - test_stats = evaluate(nas_config, data_loader_val, model, device, args) + test_stats = evaluate(nas_config, args.nas_test_config, data_loader_val, model, device, args) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")