Yixiao Fang b6585cc1f1
[Fix] Fix knn multi-gpu bug (#634)
* update knn

* update

* fix bugs of knn

* update entrance scripts

* update configs and related codes

* update sampler config

* remove redundancy

* update

* update docs

* fix lint

* update logic of loading ckpt
2022-12-23 11:45:48 +08:00

185 lines
6.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import torch
from mmengine.config import Config, DictAction
from mmengine.dist import get_rank, init_dist
from mmengine.logging import MMLogger
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist
from mmselfsup.evaluation.functional import knn_eval
from mmselfsup.models.utils import Extractor
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules
def parse_args():
    """Parse command-line arguments for k-NN evaluation.

    Returns:
        argparse.Namespace: Parsed arguments. As a side effect, exports
        ``LOCAL_RANK`` into the environment when it is missing, so
        distributed launchers that read the env var keep working.
    """

    def str2bool(value):
        # BUGFIX: the old ``type=bool`` treated ANY non-empty string
        # (including 'False') as True, so ``--use-cuda False`` silently
        # kept CUDA on. Parse common textual booleans explicitly instead.
        if isinstance(value, bool):
            return value
        if value.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if value.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            f'Boolean value expected, got {value!r}')

    parser = argparse.ArgumentParser(description='KNN evaluation')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--dataset-config',
        default='configs/benchmarks/classification/knn_imagenet.py',
        help='knn dataset config file path')
    parser.add_argument(
        '--work-dir', type=str, default=None, help='the dir to save results')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # KNN settings
    parser.add_argument(
        '--num-knn',
        default=[10, 20, 100, 200],
        nargs='+',
        type=int,
        help='Number of NN to use. 20 usually works the best.')
    parser.add_argument(
        '--temperature',
        default=0.07,
        type=float,
        help='Temperature used in the voting coefficient.')
    parser.add_argument(
        '--use-cuda',
        default=True,
        type=str2bool,
        help='Store the features on GPU. Set to False if you encounter OOM')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    args = parser.parse_args()

    # Propagate the launcher's local rank so downstream dist utilities
    # (which read the env var) see a consistent value.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
def main():
    """Extract features with a self-supervised model and run k-NN eval.

    Workflow: parse CLI args -> resolve config / work_dir / logger ->
    build gallery (train) and query (val) dataloaders -> build the model
    and load weights -> optionally wrap with DDP -> extract features ->
    run the k-NN classifier on rank 0 and log Top1/Top5.
    """
    args = parse_args()

    # register all modules in mmselfsup into the registries
    register_all_modules()

    # load config
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.env_cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs/benchmarks/knn/',
                                osp.splitext(osp.basename(args.config))[0])

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher)

    # create work_dir
    mkdir_or_exist(osp.abspath(cfg.work_dir))

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'knn_{timestamp}.log')
    logger = MMLogger.get_instance(
        'mmselfsup',
        logger_name='mmselfsup',
        log_file=log_file,
        log_level=cfg.log_level)

    # build dataloaders
    dataset_cfg = Config.fromfile(args.dataset_config)
    data_loader_train = Runner.build_dataloader(
        dataloader=dataset_cfg.train_dataloader, seed=args.seed)
    data_loader_val = Runner.build_dataloader(
        dataloader=dataset_cfg.val_dataloader, seed=args.seed)

    # build the model
    model = MODELS.build(cfg.model)
    # model weights are determined in this priority:
    # checkpoint > init_cfg ('Pretrained') > random init
    if args.checkpoint is not None:
        logger.info(f'Use checkpoint: {args.checkpoint} to extract features')
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    elif getattr(model.backbone.init_cfg, 'type', None) == 'Pretrained':
        model.init_weights()
    else:
        logger.warning(
            'No pretrained or checkpoint is given, use random init.')

    if torch.cuda.is_available():
        model = model.cuda()
    if distributed:
        model = MMDistributedDataParallel(
            module=model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
    # unwrap DDP so the Extractor operates on the bare model
    if is_model_wrapper(model):
        model = model.module

    # build extractors and extract features
    extractor_train = Extractor(
        extract_dataloader=data_loader_train,
        seed=args.seed,
        dist_mode=distributed,
        pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
    extractor_val = Extractor(
        extract_dataloader=data_loader_val,
        seed=args.seed,
        dist_mode=distributed,
        pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
    train_feats = extractor_train(model)
    logger.info('Features from train dataset are extracted.')
    val_feats = extractor_val(model)
    logger.info('Features from validation dataset are extracted.')

    # run knn on rank 0 only
    # NOTE(review): ``--use-cuda`` is parsed but never forwarded to
    # knn_eval here — confirm whether it should be.
    rank = get_rank()
    if rank == 0:
        for key in train_feats.keys():
            # BUGFIX: use fresh local names. The previous code re-assigned
            # ``train_feats``/``val_feats`` inside the loop, destroying the
            # feature dicts after the first key and crashing multi-key runs.
            train_feat = train_feats[key]
            val_feat = val_feats[key]
            train_labels = torch.LongTensor(
                data_loader_train.dataset.get_gt_labels()).to(
                    train_feat.device)
            val_labels = torch.LongTensor(
                data_loader_val.dataset.get_gt_labels()).to(val_feat.device)
            logger.info(f'Start k-NN classification of key "{key}".')
            for k in args.num_knn:
                top1, top5 = knn_eval(train_feat, train_labels, val_feat,
                                      val_labels, k, args.temperature)
                # typo fixed: 'Reasults' -> 'Results'
                logger.info(f'Results of "{key}", {k}-NN classifier: '
                            f'Top1 - {top1}, Top5 - {top5}.')
# Script entry point: run k-NN evaluation when executed directly.
if __name__ == '__main__':
    main()