# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time

import torch
from mmengine.config import Config, DictAction
from mmengine.dist import get_rank, init_dist
from mmengine.logging import MMLogger
from mmengine.model.wrappers import (MMDistributedDataParallel,
                                     is_model_wrapper)
from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist

from mmselfsup.evaluation.functional import knn_eval
from mmselfsup.models.utils import Extractor
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='KNN evaluation')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--dataset-config',
        default='configs/benchmarks/classification/knn_imagenet.py',
        help='knn dataset config file path')
    parser.add_argument(
        '--work-dir', type=str, default=None, help='the dir to save results')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # KNN settings
    parser.add_argument(
        '--num-knn',
        default=[10, 20, 100, 200],
        nargs='+',
        type=int,
        help='Number of NN to use. 20 usually works the best.')
    parser.add_argument(
        '--temperature',
        default=0.07,
        type=float,
        help='Temperature used in the voting coefficient.')
    parser.add_argument(
        '--use-cuda',
        default=True,
        type=bool,
        help='Store the features on GPU. Set to False if you encounter OOM')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    args = parse_args()

    # register all modules in mmselfsup into the registries
    register_all_modules()

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.env_cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs/benchmarks/knn/',
                                osp.splitext(osp.basename(args.config))[0])

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher)

    # create work_dir
    mkdir_or_exist(osp.abspath(cfg.work_dir))

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'knn_{timestamp}.log')
    logger = MMLogger.get_instance(
        'mmselfsup',
        logger_name='mmselfsup',
        log_file=log_file,
        log_level=cfg.log_level)

    # build dataloaders for the train and validation splits
    dataset_cfg = Config.fromfile(args.dataset_config)
    data_loader_train = Runner.build_dataloader(
        dataloader=dataset_cfg.train_dataloader, seed=args.seed)
    data_loader_val = Runner.build_dataloader(
        dataloader=dataset_cfg.val_dataloader, seed=args.seed)

    # build the model
    model = MODELS.build(cfg.model)

    # model is determined in this priority: checkpoint > init_cfg > random
    if args.checkpoint is not None:
        logger.info(f'Use checkpoint: {args.checkpoint} to extract features')
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    elif getattr(model.backbone.init_cfg, 'type', None) == 'Pretrained':
        model.init_weights()
    else:
        logger.warning(
            'No pretrained or checkpoint is given, use random init.')

    if torch.cuda.is_available():
        model = model.cuda()

    if distributed:
        model = MMDistributedDataParallel(
            module=model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    if is_model_wrapper(model):
        model = model.module

    # build extractors and extract features from both splits
    extractor_train = Extractor(
        extract_dataloader=data_loader_train,
        seed=args.seed,
        dist_mode=distributed,
        pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
    extractor_val = Extractor(
        extract_dataloader=data_loader_val,
        seed=args.seed,
        dist_mode=distributed,
        pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
    train_feats = extractor_train(model)
    logger.info('Features from train dataset are extracted.')
    val_feats = extractor_val(model)
    logger.info('Features from validation dataset are extracted.')

    # run k-NN classification on rank 0 only
    rank = get_rank()
    if rank == 0:
        for key in train_feats.keys():
            # use per-key local variables so the feature dicts are not
            # overwritten while iterating over their keys
            train_feat = train_feats[key]
            val_feat = val_feats[key]
            train_labels = torch.LongTensor(
                data_loader_train.dataset.get_gt_labels()).to(
                    train_feat.device)
            val_labels = torch.LongTensor(
                data_loader_val.dataset.get_gt_labels()).to(val_feat.device)
            logger.info(f'Start k-NN classification of key "{key}".')
            for k in args.num_knn:
                top1, top5 = knn_eval(train_feat, train_labels, val_feat,
                                      val_labels, k, args.temperature)
                logger.info(f'Results of "{key}", {k}-NN classifier: '
                            f'Top1 - {top1}, Top5 - {top5}.')


if __name__ == '__main__':
    main()
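
# Example invocation sketch (illustrative only; the script path and the
# ${...} placeholders are assumptions, adjust them to your checkout):
#
#   python tools/benchmarks/classification/knn_imagenet/test_knn.py \
#       ${SELFSUP_CONFIG} \
#       --checkpoint ${CHECKPOINT} \
#       --work-dir ${WORK_DIR}
#
# For multi-GPU evaluation, pass --launcher pytorch (or slurm/mpi) and start
# the script through the matching launcher, e.g. torch.distributed.launch,
# which supplies the --local_rank argument handled in parse_args above.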