From 1868bd2acabf62845795a1cc883cd83acce2c653 Mon Sep 17 00:00:00 2001 From: brian1009 Date: Sat, 17 Jun 2023 12:09:16 +0800 Subject: [PATCH] [fix, NFC] fix evaluation bug when running in non-nas mode. Check if nas is enable when running related command --- engine.py | 15 ++++++++------- main.py | 8 ++++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/engine.py b/engine.py index 977f566..c68a6ad 100644 --- a/engine.py +++ b/engine.py @@ -69,18 +69,19 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, @torch.no_grad() -def evaluate(nas_config, data_loader, model, device): +def evaluate(nas_config, data_loader, model, device, args = None): criterion = torch.nn.CrossEntropyLoss() metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' - # Sample the smallest subnetwork to test accuracy - smallest_config = [] - for ratios in nas_config['sparsity']['choices']: - smallest_config.append(ratios[0]) - # smallest_config.append([1, 3]) - model.module.set_sample_config(smallest_config) + if args.nas_mode: + # Sample the smallest subnetwork to test accuracy + smallest_config = [] + for ratios in nas_config['sparsity']['choices']: + smallest_config.append(ratios[0]) + # smallest_config.append([1, 3]) + model.module.set_sample_config(smallest_config) # switch to evaluation mode model.eval() diff --git a/main.py b/main.py index 0df55f2..8c26d8c 100644 --- a/main.py +++ b/main.py @@ -199,7 +199,7 @@ def get_args_parser(): parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training') - parser.add_argument('--nas-mode', action='store_true', default=True) + parser.add_argument('--nas-mode', action='store_true', default=False) # parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight') # parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight') # parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight') @@ -307,7 +307,7 @@ def main(args): print(f"Creating model: {args.model}") model = create_model( args.model, - pretrained=True, + pretrained=False, num_classes=args.nb_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, @@ -481,7 +481,7 @@ def main(args): loss_scaler.load_state_dict(checkpoint['scaler']) lr_scheduler.step(args.start_epoch) if args.eval: - test_stats = evaluate(nas_config, data_loader_val, model, device) + test_stats = evaluate(nas_config, data_loader_val, model, device, args) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") return @@ -514,7 +514,7 @@ def main(args): }, checkpoint_path) - test_stats = evaluate(nas_config, data_loader_val, model, device) + test_stats = evaluate(nas_config, data_loader_val, model, device, args) print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")