mirror of
https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00
[fix, NFC] fix evaluation bug when running in non-nas mode. Check if nas is enable when running related command
This commit is contained in:
parent
43d74ab8b8
commit
1868bd2aca
@ -69,12 +69,13 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def evaluate(nas_config, data_loader, model, device):
|
def evaluate(nas_config, data_loader, model, device, args = None):
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||||
header = 'Test:'
|
header = 'Test:'
|
||||||
|
|
||||||
|
if args.nas_mode:
|
||||||
# Sample the smallest subnetwork to test accuracy
|
# Sample the smallest subnetwork to test accuracy
|
||||||
smallest_config = []
|
smallest_config = []
|
||||||
for ratios in nas_config['sparsity']['choices']:
|
for ratios in nas_config['sparsity']['choices']:
|
||||||
|
8
main.py
8
main.py
@ -199,7 +199,7 @@ def get_args_parser():
|
|||||||
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
|
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
|
||||||
help='Name of model to train')
|
help='Name of model to train')
|
||||||
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
|
parser.add_argument('--nas-config', type=str, default='configs/deit_small_nxm_nas_124.yaml', help='configuration for supernet training')
|
||||||
parser.add_argument('--nas-mode', action='store_true', default=True)
|
parser.add_argument('--nas-mode', action='store_true', default=False)
|
||||||
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
|
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
|
||||||
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
|
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
|
||||||
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
|
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
|
||||||
@ -307,7 +307,7 @@ def main(args):
|
|||||||
print(f"Creating model: {args.model}")
|
print(f"Creating model: {args.model}")
|
||||||
model = create_model(
|
model = create_model(
|
||||||
args.model,
|
args.model,
|
||||||
pretrained=True,
|
pretrained=False,
|
||||||
num_classes=args.nb_classes,
|
num_classes=args.nb_classes,
|
||||||
drop_rate=args.drop,
|
drop_rate=args.drop,
|
||||||
drop_path_rate=args.drop_path,
|
drop_path_rate=args.drop_path,
|
||||||
@ -481,7 +481,7 @@ def main(args):
|
|||||||
loss_scaler.load_state_dict(checkpoint['scaler'])
|
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||||
lr_scheduler.step(args.start_epoch)
|
lr_scheduler.step(args.start_epoch)
|
||||||
if args.eval:
|
if args.eval:
|
||||||
test_stats = evaluate(nas_config, data_loader_val, model, device)
|
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
|
||||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -514,7 +514,7 @@ def main(args):
|
|||||||
}, checkpoint_path)
|
}, checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
test_stats = evaluate(nas_config, data_loader_val, model, device)
|
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
|
||||||
|
|
||||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user