mirror of https://github.com/facebookresearch/deit
[fix, NFC] fix evaluation bug when running in non-nas mode. Check if nas is enable when running related command
parent
43d74ab8b8
commit
1868bd2aca
15
engine.py
15
engine.py
|
@ -69,18 +69,19 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(nas_config, data_loader, model, device):
|
||||
def evaluate(nas_config, data_loader, model, device, args = None):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
header = 'Test:'
|
||||
|
||||
# Sample the smallest subnetwork to test accuracy
|
||||
smallest_config = []
|
||||
for ratios in nas_config['sparsity']['choices']:
|
||||
smallest_config.append(ratios[0])
|
||||
# smallest_config.append([1, 3])
|
||||
model.module.set_sample_config(smallest_config)
|
||||
if args.nas_mode:
|
||||
# Sample the smallest subnetwork to test accuracy
|
||||
smallest_config = []
|
||||
for ratios in nas_config['sparsity']['choices']:
|
||||
smallest_config.append(ratios[0])
|
||||
# smallest_config.append([1, 3])
|
||||
model.module.set_sample_config(smallest_config)
|
||||
|
||||
# switch to evaluation mode
|
||||
model.eval()
|
||||
|
|
8
main.py
8
main.py
|
@ -199,7 +199,7 @@ def get_args_parser():
|
|||
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
|
||||
parser.add_argument('--nas-mode', action='store_true', default=True)
|
||||
parser.add_argument('--nas-mode', action='store_true', default=False)
|
||||
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
|
||||
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
|
||||
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
|
||||
|
@ -307,7 +307,7 @@ def main(args):
|
|||
print(f"Creating model: {args.model}")
|
||||
model = create_model(
|
||||
args.model,
|
||||
pretrained=True,
|
||||
pretrained=False,
|
||||
num_classes=args.nb_classes,
|
||||
drop_rate=args.drop,
|
||||
drop_path_rate=args.drop_path,
|
||||
|
@ -481,7 +481,7 @@ def main(args):
|
|||
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||
lr_scheduler.step(args.start_epoch)
|
||||
if args.eval:
|
||||
test_stats = evaluate(nas_config, data_loader_val, model, device)
|
||||
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
return
|
||||
|
||||
|
@ -514,7 +514,7 @@ def main(args):
|
|||
}, checkpoint_path)
|
||||
|
||||
|
||||
test_stats = evaluate(nas_config, data_loader_val, model, device)
|
||||
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
|
||||
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
|
||||
|
|
Loading…
Reference in New Issue