# EasyCV/tools/test.py
# (146 lines, 4.6 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import importlib
import os
import os.path as osp
import sys
import time
import mmcv
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from easycv.core.evaluation.builder import build_evaluator
from easycv.datasets import build_dataloader, build_dataset
from easycv.models import build_model
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
# from mmcv import Config
from easycv.utils.config_tools import mmcv_config_fromfile, traverse_replace
from easycv.utils.logger import get_root_logger
# Extend the import search path with two directories derived from this
# script's location: the directory two levels above ``__file__`` and the one
# above that. Both are appended (lowest priority), in that order.
# NOTE(review): these entries are added after the ``easycv`` imports above
# have already executed, so they cannot help resolve those imports — confirm
# whether they were meant to precede them.
sys.path.extend(
    os.path.abspath(search_dir) for search_dir in (
        os.path.dirname(os.path.dirname(__file__)),
        osp.join(os.path.dirname(os.path.dirname(__file__)), '../'),
    ))
def single_gpu_test(model, data_loader):
    """Run non-distributed inference over every batch of ``data_loader``.

    Args:
        model: Model (in ``main`` wrapped in ``MMDataParallel``) whose
            forward accepts ``mode='test'`` plus the batch as keyword args.
        data_loader: Loader over the evaluation dataset; its ``dataset``
            length is passed to the collector as the number of samples.

    Returns:
        The value produced by ``nondist_forward_collect``.
    """
    model.eval()

    def _test_step(**batch):
        # Forward one batch in test mode.
        return model(mode='test', **batch)

    num_samples = len(data_loader.dataset)
    return nondist_forward_collect(_test_step, data_loader, num_samples)
def multi_gpu_test(model, data_loader):
    """Run distributed inference and collect the outputs across ranks.

    Args:
        model: Model (in ``main`` wrapped in ``MMDistributedDataParallel``)
            whose forward accepts ``mode='test'`` plus the batch as keyword
            args.
        data_loader: Distributed loader over the evaluation dataset; its
            ``dataset`` length is passed to the collector.

    Returns:
        The value produced by ``dist_forward_collect``; ``main`` treats it as
        a dict mapping output names to numpy arrays.
    """
    model.eval()

    def forward_test(**batch):
        # Forward one batch in test mode.
        return model(mode='test', **batch)

    # Only this process's rank is needed by the collector.
    rank, _ = get_dist_info()
    return dist_forward_collect(forward_test, data_loader, rank,
                                len(data_loader.dataset))
def parse_args(argv=None):
    """Parse the command-line arguments of the test script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``checkpoint``, ``work_dir``, ``launcher``, ``local_rank``, ``port``
        and ``model_type``.

    Side effects:
        Sets the ``LOCAL_RANK`` environment variable from ``local_rank`` when
        it is not already defined.
    """
    parser = argparse.ArgumentParser(
        description='EasyCV test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and models')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Newer PyTorch launchers pass ``--local-rank`` (hyphen) instead of
    # ``--local_rank``; accept both, stored under the same ``local_rank`` dest.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--port',
        type=int,
        default=29500,
        help='port only works when launcher=="slurm"')
    parser.add_argument(
        '--model_type',
        choices=['classification', 'pose'],
        default='classification',
        help='model type')
    args = parser.parse_args(argv)
    # Keep an existing LOCAL_RANK (e.g. exported by the launcher) untouched.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args
def main():
    """Evaluate a checkpointed EasyCV model on the validation split.

    Reads the config and checkpoint given on the command line, optionally
    initialises a distributed environment, runs the model over
    ``cfg.data.val`` and evaluates the collected outputs on rank 0 only.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is
    # loaded; import it explicitly before calling ``find_spec``.
    import importlib.util

    args = parse_args()
    cfg = mmcv_config_fromfile(args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    # ensure weights come from the checkpoint rather than from pretraining
    cfg.model.pretrained = None

    # check memcached package exists; presumably this switches 'memcached'
    # config entries off when the ``mc`` module is unavailable
    if importlib.util.find_spec('mc') is None:
        traverse_replace(cfg, 'memcached', False)

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # logger: the log file lives in work_dir, so make sure it exists first
    os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'test_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # build the dataloader
    dataset = build_dataset(cfg.data.val)
    data_loader = build_dataloader(
        dataset,
        imgs_per_gpu=cfg.data.imgs_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    model = build_model(cfg.model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if distributed:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader)  # dict{key: np.ndarray}
    else:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader)

    # only rank 0 evaluates the collected outputs
    rank, _ = get_dist_info()
    if rank != 0:
        return

    if args.model_type == 'pose':
        evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])
        eval_result = dataset.evaluate(outputs, evaluators)
        # the pose result was previously discarded; record it in the log
        logger.info('Evaluation result: %s', eval_result)
    else:
        for name, val in outputs.items():
            dataset.evaluate(
                torch.from_numpy(val), name, logger, topk=(1, 5))


if __name__ == '__main__':
    main()