# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcls.apis import set_random_seed
from mmcls.models import build_classifier
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmfewshot.classification.apis import (multi_gpu_meta_test,
                                           single_gpu_meta_test)
from mmfewshot.classification.datasets import (build_dataset,
                                               build_meta_test_dataloader)
from mmfewshot.utils import get_root_logger


def parse_args():
    """Parse the command line arguments of the meta test script.

    Besides parsing, this function has two side effects:

    - ``os.environ['LOCAL_RANK']`` is set from ``--local_rank`` when the
      variable is not already present in the environment.
    - a warning is emitted when ``--device`` is given.

    Note that ``--metrics`` and ``--metric-options`` are accepted but are
    not read by :func:`main`, which takes the metric from
    ``cfg.evaluation.metric``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='mmcls test model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--metrics',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., '
        '"accuracy", "precision", "recall", "f1_score", "support" for single '
        'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
        'multi-label dataset')
    parser.add_argument(
        '--show-task-results',
        action='store_true',
        help='whether to show results of each task')
    parser.add_argument('--tmpdir', help='tmp dir for writing some results')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--metric-options',
        nargs='+',
        action=DictAction,
        default={},
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be parsed as a dict metric_options for dataset.evaluate()'
        ' function.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--device', default=None, help='device used for testing. (Deprecated)')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    # Underscore spelling of `--show-task-results`; both options store into
    # the same `show_task_results` destination.
    parser.add_argument(
        '--show_task_results',
        action='store_true',
        help='whether to record eval result of each task.')
    args = parser.parse_args()

    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.device:
        warnings.warn(
            '--device is deprecated. To use cpu to test, please '
            'refers to https://mmclassification.readthedocs.io/en/latest/'
            'getting_started.html#inference-with-pretrained-models')

    # NOTE: there is deliberately no `args.metrics or args.out` check here.
    # This parser defines no `--out` option, so reading `args.out` raised
    # AttributeError whenever `--metrics` was omitted, and `main` consumes
    # neither value.
    return args


def main():
    """Run meta testing of a classifier checkpoint.

    The steps are: load the config and merge ``--options`` into it, resolve
    the work directory and GPU id, optionally initialise the distributed
    environment, create the logger, seed the random generators, build the
    test dataset with its meta test dataloaders, build the model and load
    the checkpoint, run the single- or multi-GPU meta test, and finally log
    the returned evaluation results on rank 0.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed testing. Use the first GPU '
                      'in `gpu_ids` now.')
    else:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()
    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}_test.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # set random seeds
    if args.seed is None:
        args.seed = 0
    logger.info(f'Set random seed to {args.seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(args.seed, deterministic=args.deterministic)

    dataset = build_dataset(cfg.data.test)
    meta_test_cfg = cfg.data.test.meta_test_cfg
    (support_data_loader, query_data_loader,
     all_data_loader) = build_meta_test_dataloader(dataset, meta_test_cfg)

    # build the model and load checkpoint
    model = build_classifier(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if rank == 0:
        logger.info(f'load from checkpoint: {args.checkpoint} ')
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if not distributed:
        if args.device == 'cpu':
            model = model.cpu()
        else:
            # Use the GPU selected through `--gpu-id` / `--gpu-ids` instead
            # of always device 0; `cfg.gpu_ids` holds exactly one id here.
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        meta_eval_results = single_gpu_meta_test(
            model,
            meta_test_cfg.num_episodes,
            support_data_loader,
            query_data_loader,
            all_data_loader,
            meta_test_cfg=meta_test_cfg,
            logger=logger,
            eval_kwargs=dict(metric=cfg.evaluation.metric),
            show_task_results=args.show_task_results)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        meta_eval_results = multi_gpu_meta_test(
            model,
            meta_test_cfg.num_episodes,
            support_data_loader,
            query_data_loader,
            all_data_loader,
            meta_test_cfg=meta_test_cfg,
            logger=logger,
            eval_kwargs=dict(metric=cfg.evaluation.metric),
            show_task_results=args.show_task_results)

    if rank == 0:
        logger.info(f'Checkpoint: {args.checkpoint}')
        for k, v in meta_eval_results.items():
            logger.info(f'{k} : {v:.2f}')


if __name__ == '__main__':
    main()