mirror of
https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00
add nas_test_config
This commit is contained in:
parent
e3de89e879
commit
4dda0672ba
17
engine.py
17
engine.py
@ -69,19 +69,24 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(nas_config, data_loader, model, device, args = None):
|
||||
def evaluate(nas_config, nas_test_config, data_loader, model, device, args = None):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
header = 'Test:'
|
||||
|
||||
if args.nas_mode:
|
||||
# Sample the smallest subnetwork to test accuracy
|
||||
smallest_config = []
|
||||
# Sample the subnet to test accuracy
|
||||
test_config = []
|
||||
for ratios in nas_config['sparsity']['choices']:
|
||||
smallest_config.append(ratios[0])
|
||||
# smallest_config.append([1, 3])
|
||||
model.module.set_sample_config(smallest_config)
|
||||
if nas_test_config in ratios:
|
||||
test_config.append(nas_test_config)
|
||||
else:
|
||||
# choose smallest_config
|
||||
test_config.append(ratios[0])
|
||||
print(f'Test config {nas_test_config} is not in the choices, choose smallest config {ratios[0]}')
|
||||
|
||||
model.module.set_sample_config(test_config)
|
||||
|
||||
# switch to evaluation mode
|
||||
model.eval()
|
||||
|
17
main.py
17
main.py
@ -149,7 +149,7 @@ def get_args_parser():
|
||||
# parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
|
||||
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'soft_fd'], type=str, help="")
|
||||
# parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
|
||||
parser.add_argument('--distillation-alpha', default=0.0, type=float, help="")
|
||||
parser.add_argument('--distillation-alpha', default=1.0, type=float, help="")
|
||||
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
|
||||
parser.add_argument('--distillation-gamma', default=0.1, type=float,
|
||||
help="coefficient for hidden distillation loss, we set it to be 0.1 by aligning MiniViT")
|
||||
@ -159,7 +159,7 @@ def get_args_parser():
|
||||
parser.add_argument('--attn-only', action='store_true')
|
||||
|
||||
# Dataset parameters
|
||||
parser.add_argument('--data-path', default='/dev/shm/imagenet', type=str, help='dataset path')
|
||||
parser.add_argument('--data-path', default='/dataset/imagenet', type=str, help='dataset path')
|
||||
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
|
||||
type=str, help='Image Net dataset path')
|
||||
parser.add_argument('--inat-category', default='name',
|
||||
@ -187,17 +187,12 @@ def get_args_parser():
|
||||
|
||||
# Sparsity Training Related Flag
|
||||
# timm == 0.4.12
|
||||
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --output_dir result_nas_1:4_150epoch_repeat
|
||||
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir result_nas_124+13_150epoch
|
||||
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch
|
||||
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch
|
||||
# python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval
|
||||
# python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch
|
||||
parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--pretrained', action='store_true', help='Use pretrained model')
|
||||
parser.add_argument('--nas-config', type=str, default=None, help='configuration for supernet training')
|
||||
parser.add_argument('--nas-mode', action='store_true', default=False)
|
||||
parser.add_argument('--nas-test-config', type=int, nargs='+', default=None, help='Use test config to eval accuracy')
|
||||
# parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
|
||||
# parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
|
||||
# parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
|
||||
@ -483,8 +478,10 @@ def main(args):
|
||||
if 'scaler' in checkpoint:
|
||||
loss_scaler.load_state_dict(checkpoint['scaler'])
|
||||
lr_scheduler.step(args.start_epoch)
|
||||
|
||||
# Evaluate only
|
||||
if args.eval:
|
||||
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
|
||||
test_stats = evaluate(nas_config, args.nas_test_config, data_loader_val, model, device, args)
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
return
|
||||
|
||||
@ -517,7 +514,7 @@ def main(args):
|
||||
}, checkpoint_path)
|
||||
|
||||
|
||||
test_stats = evaluate(nas_config, data_loader_val, model, device, args)
|
||||
test_stats = evaluate(nas_config, args.nas_test_config, data_loader_val, model, device, args)
|
||||
|
||||
print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user