add nas_test_config

max410011 2023-06-26 17:15:31 +00:00
parent e3de89e879
commit 4dda0672ba
2 changed files with 18 additions and 16 deletions


@@ -69,19 +69,24 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
 @torch.no_grad()
-def evaluate(nas_config, data_loader, model, device, args = None):
+def evaluate(nas_config, nas_test_config, data_loader, model, device, args = None):
     criterion = torch.nn.CrossEntropyLoss()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
     if args.nas_mode:
-        # Sample the smallest subnetwork to test accuracy
-        smallest_config = []
+        # Sample the subnet to test accuracy
+        test_config = []
         for ratios in nas_config['sparsity']['choices']:
-            smallest_config.append(ratios[0])
-            # smallest_config.append([1, 3])
-        model.module.set_sample_config(smallest_config)
+            if nas_test_config in ratios:
+                test_config.append(nas_test_config)
+            else:
+                # choose smallest_config
+                test_config.append(ratios[0])
+                print(f'Test config {nas_test_config} is not in the choices, choose smallest config {ratios[0]}')
+        model.module.set_sample_config(test_config)
     # switch to evaluation mode
     model.eval()

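For reference, the added branch picks the requested pattern for each layer and falls back to the smallest choice when that pattern is not offered. A minimal, self-contained sketch of that selection logic; the nas_config contents and the [N, M] choice format below are illustrative assumptions, not the repo's actual YAML:

# Standalone sketch of the subnet selection added to evaluate().
nas_config = {'sparsity': {'choices': [
    [[1, 4], [2, 4]],          # layer 0: assumed 1:4 and 2:4 choices
    [[1, 4], [2, 4], [1, 3]],  # layer 1: assumed to also offer 1:3
]}}
nas_test_config = [2, 4]       # what argparse would produce for `--nas-test-config 2 4`

test_config = []
for ratios in nas_config['sparsity']['choices']:
    if nas_test_config in ratios:      # requested pattern is a valid choice for this layer
        test_config.append(nas_test_config)
    else:                              # otherwise fall back to the smallest (first) choice
        test_config.append(ratios[0])

print(test_config)                     # [[2, 4], [2, 4]]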
main.py

@ -149,7 +149,7 @@ def get_args_parser():
# parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'soft_fd'], type=str, help="")
# parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation-alpha', default=0.0, type=float, help="")
parser.add_argument('--distillation-alpha', default=1.0, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
parser.add_argument('--distillation-gamma', default=0.1, type=float,
help="coefficient for hidden distillation loss, we set it to be 0.1 by aligning MiniViT")
@@ -159,7 +159,7 @@ def get_args_parser():
     parser.add_argument('--attn-only', action='store_true')

     # Dataset parameters
-    parser.add_argument('--data-path', default='/dev/shm/imagenet', type=str, help='dataset path')
+    parser.add_argument('--data-path', default='/dataset/imagenet', type=str, help='dataset path')
     parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                         type=str, help='Image Net dataset path')
     parser.add_argument('--inat-category', default='name',
@@ -187,17 +187,12 @@ def get_args_parser():
     # Sparsity Training Related Flag
     # timm == 0.4.12
-    # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --output_dir result_nas_1:4_150epoch_repeat
-    # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir result_nas_124+13_150epoch
-    # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29500 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --epochs 50 --output_dir result_sub_1:4_50epoch
-    # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform24.yaml --epochs 50 --output_dir result_sub_2:4_50epoch
-    # python -m torch.distributed.launch --nproc_per_node=8 --use_env --master_port 29501 main.py --nas-config configs/deit_small_nxm_uniform14.yaml --eval
-    # python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --epochs 150 --nas-config configs/deit_small_nxm_nas_124+13.yaml --output_dir twined_nas_124+13_150epoch
     parser.add_argument('--model', default='Sparse_deit_small_patch16_224', type=str, metavar='MODEL',
                         help='Name of model to train')
     parser.add_argument('--pretrained', action='store_true', help='Use pretrained model')
     parser.add_argument('--nas-config', type=str, default=None, help='configuration for supernet training')
     parser.add_argument('--nas-mode', action='store_true', default=False)
+    parser.add_argument('--nas-test-config', type=int, nargs='+', default=None, help='Use test config to eval accuracy')
     # parser.add_argument('--nas-weights', default='weights/nas_pretrained.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_nas_1:4_150epoch/checkpoint.pth', help='load pretrained supernet weight')
     # parser.add_argument('--nas-weights', default='result_sub_1:4_50epoch/best_checkpoint.pth', help='load pretrained supernet weight')
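Because the new flag uses type=int with nargs='+', a sparsity pattern is passed as separate integers on the command line and reaches main() as a list, which evaluate() then matches against each layer's choices. A small argparse-only sketch; only the flag definition is taken from this diff, the rest is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--nas-test-config', type=int, nargs='+', default=None,
                    help='Use test config to eval accuracy')

args = parser.parse_args(['--nas-test-config', '2', '4'])
print(args.nas_test_config)  # [2, 4]

In the style of the commented launch commands in this file, an evaluation run would then presumably add something like --nas-mode --nas-test-config 2 4 --eval to the usual --nas-config invocation (flag combination assumed, not verified against the repo).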
@@ -483,8 +478,10 @@ def main(args):
             if 'scaler' in checkpoint:
                 loss_scaler.load_state_dict(checkpoint['scaler'])
         lr_scheduler.step(args.start_epoch)

+    # Evaluate only
     if args.eval:
-        test_stats = evaluate(nas_config, data_loader_val, model, device, args)
+        test_stats = evaluate(nas_config, args.nas_test_config, data_loader_val, model, device, args)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
         return
@@ -517,7 +514,7 @@ def main(args):
                 }, checkpoint_path)

-        test_stats = evaluate(nas_config, data_loader_val, model, device, args)
+        test_stats = evaluate(nas_config, args.nas_test_config, data_loader_val, model, device, args)
         print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")