diff --git a/tools/test.py b/tools/test.py index 6e8850f6..bc0afc7d 100644 --- a/tools/test.py +++ b/tools/test.py @@ -18,7 +18,7 @@ from mmselfsup.utils import (get_root_logger, multi_gpu_test, def parse_args(): parser = argparse.ArgumentParser( - description='MMDet test (and eval) a model') + description='MMSelfSup test (and eval) a model') parser.add_argument('config', help='test config file path') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument( @@ -86,6 +86,9 @@ def main(): distributed = True init_dist(args.launcher, **cfg.dist_params) + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # logger timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) log_file = osp.join(cfg.work_dir, f'test_{timestamp}.log') @@ -116,8 +119,7 @@ def main(): rank, _ = get_dist_info() if rank == 0: - for name, val in outputs.items(): - dataset.evaluate(torch.from_numpy(val), name, logger, topk=(1, 5)) + dataset.evaluate(outputs, logger, topk=(1, 5)) if __name__ == '__main__':