[fix, NFC] fix evaluation bug when running in non-nas mode. Check if nas is enable when running related command

pull/227/head
brian1009 2023-06-17 12:09:16 +08:00
parent 43d74ab8b8
commit 1868bd2aca
2 changed files with 12 additions and 11 deletions

View File

@ -69,18 +69,19 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
@torch.no_grad()
def evaluate(nas_config, data_loader, model, device):
def evaluate(nas_config, data_loader, model, device, args = None):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# Sample the smallest subnetwork to test accuracy
smallest_config = []
for ratios in nas_config['sparsity']['choices']:
smallest_config.append(ratios[0])
# smallest_config.append([1, 3])
model.module.set_sample_config(smallest_config)
if args.nas_mode:
# Sample the smallest subnetwork to test accuracy
smallest_config = []
for ratios in nas_config['sparsity']['choices']:
smallest_config.append(ratios[0])
# smallest_config.append([1, 3])
model.module.set_sample_config(smallest_config)
# switch to evaluation mode
model.eval()

View File

@ -199,7 +199,7 @@ def get_args_parser():
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
parser.add_argument('--nas-mode', action='store_true', default=True)
parser.add_argument('--nas-mode', action='store_true', default=False)
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
@ -307,7 +307,7 @@ def main(args):
print(f"Creating model: {args.model}")
model = create_model(
args.model,
pretrained=True,
pretrained=False,
num_classes=args.nb_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
@ -481,7 +481,7 @@ def main(args):
loss_scaler.load_state_dict(checkpoint['scaler'])
lr_scheduler.step(args.start_epoch)
if args.eval:
test_stats = evaluate(nas_config, data_loader_val, model, device)
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
@ -514,7 +514,7 @@ def main(args):
}, checkpoint_path)
test_stats = evaluate(nas_config, data_loader_val, model, device)
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")