# Copyright (c) Alibaba, Inc. and its affiliates.
"""Test (and evaluate) a trained EasyCV model checkpoint.

Builds the validation dataset (``cfg.data.val``) and the model described by a
config file, loads the given checkpoint, runs test-mode inference either in a
single process or as a distributed job, and evaluates the collected outputs
on rank 0.
"""
import argparse
# ``import importlib`` alone does not guarantee ``importlib.util`` is loaded.
import importlib.util
import os
import os.path as osp
import sys
import time

# Make the repository root (and its parent) importable when this script is
# run from a source checkout without installing easycv. This must run BEFORE
# the easycv imports below, otherwise it cannot affect them.
_REPO_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
sys.path.append(_REPO_ROOT)
sys.path.append(osp.abspath(osp.join(_REPO_ROOT, '../')))

import mmcv  # noqa: E402
import torch  # noqa: E402
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel  # noqa: E402
from mmcv.runner import get_dist_info, init_dist, load_checkpoint  # noqa: E402

from easycv.core.evaluation.builder import build_evaluator  # noqa: E402
from easycv.datasets import build_dataloader, build_dataset  # noqa: E402
from easycv.models import build_model  # noqa: E402
from easycv.utils import (dist_forward_collect, get_root_logger,  # noqa: E402
                          nondist_forward_collect, traverse_replace)
from easycv.utils.config_tools import mmcv_config_fromfile  # noqa: E402


def _test_mode_forward(model):
    """Return a callable that runs ``model(mode='test', **batch)``.

    The forward-collect helpers invoke the returned function as
    ``func(**batch)`` for every batch produced by the data loader.
    """

    def forward(**batch):
        return model(mode='test', **batch)

    return forward


def single_gpu_test(model, data_loader):
    """Run test-mode inference over ``data_loader`` in a single process.

    Args:
        model: Model (wrapper) whose forward accepts ``mode='test'``.
        data_loader: Loader whose ``dataset`` length is the number of
            samples to collect.

    Returns:
        The result of ``nondist_forward_collect``; ``main`` consumes it as a
        dict mapping output name to ``np.ndarray``.
    """
    model.eval()
    return nondist_forward_collect(
        _test_mode_forward(model), data_loader, len(data_loader.dataset))


def multi_gpu_test(model, data_loader):
    """Run test-mode inference in a distributed job and gather the outputs.

    Args:
        model: Distributed model (wrapper) whose forward accepts
            ``mode='test'``.
        data_loader: Distributed loader; ``len(data_loader.dataset)`` is the
            total number of samples to collect across all ranks.

    Returns:
        The result of ``dist_forward_collect`` for this process's rank.
    """
    model.eval()
    rank, _ = get_dist_info()
    return dist_forward_collect(
        _test_mode_forward(model), data_loader, rank,
        len(data_loader.dataset))


def parse_args():
    """Parse command-line arguments and export ``LOCAL_RANK`` if unset.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='EasyCV test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and models')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Newer torch.distributed launchers pass ``--local-rank`` (hyphen);
    # accept both spellings. ``dest`` stays ``local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--port',
        type=int,
        default=29500,
        help='port only works when launcher=="slurm"')
    parser.add_argument(
        '--model_type',
        choices=['classification', 'pose'],
        default='classification',
        help='model type')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    """Build dataset and model from the config, run inference, evaluate."""
    args = parse_args()

    cfg = mmcv_config_fromfile(args.config)
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # CLI --work_dir overrides the config. If neither provides one, fall back
    # to ./work_dirs/<config name> so the log file below has a location.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    # Weights come from the checkpoint, not from pretraining.
    cfg.model.pretrained = None

    # If the memcached client package ``mc`` is not installed, replace the
    # config's ``memcached`` entries with False via traverse_replace.
    if importlib.util.find_spec('mc') is None:
        traverse_replace(cfg, 'memcached', False)

    # Init the distributed env first, since the logger depends on dist info.
    distributed = args.launcher != 'none'
    if distributed:
        # Tolerate configs without a ``dist_params`` section.
        dist_params = dict(cfg.get('dist_params', None) or {})
        if args.launcher == 'slurm':
            dist_params['port'] = args.port
        init_dist(args.launcher, **dist_params)

    # Logger: make sure the work dir exists before creating the log file.
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'test_{}.log'.format(timestamp))
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.get('log_level', 'INFO'))

    # Build the (unshuffled) validation dataloader.
    dataset = build_dataset(cfg.data.val)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=cfg.data.imgs_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # Build the model and load the checkpoint on CPU before moving it.
    model = build_model(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader)  # dict{key: np.ndarray}

    # Only rank 0 evaluates.
    rank, _ = get_dist_info()
    if rank != 0:
        return

    if args.model_type == 'pose':
        evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])
        eval_results = dataset.evaluate(outputs, evaluators)
        # NOTE(review): the return type depends on the dataset's evaluate()
        # implementation; it is logged as-is so the metrics are visible.
        logger.info('Evaluation results: {}'.format(eval_results))
    else:
        for name, val in outputs.items():
            dataset.evaluate(
                torch.from_numpy(val), name, logger, topk=(1, 5))


if __name__ == '__main__':
    main()