mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
* [Fix]: Fix mmcls upgrade bug (#235) * [Feature]: Add multi machine dist_train (#232) * [Feature]: Add multi machine dist_train * [Fix]: Change bash to sh * [Fix]: Fix missing sh suffix * [Refactor]: Change bash to sh * [Refactor] Add unit test (#234) * [Refactor] add unit test * update workflow * update * [Fix] fix lint * update test * refactor moco and densecl unit test * fix lint * add unit test * update unit test * remove modification * [Feature]: Add MAE metafile (#238) * [Feature]: Add MAE metafile * [Fix]: Fix lint * [Fix]: Change LARS to AdamW in the metafile of MAE * [Fix] fix codecov bug (#241) * [Fix] fix codecov bug * update comment * [Refactor] Using MMCls backbones (#233) * [Refactor] using backbones from MMCls * [Refactor] modify the unit test * [Fix] modify default setting of out_indices * [Docs] fix lint * [Refactor] modify super init * [Refactor] remove res_layer.py * using mmcv PatchEmbed * [Fix]: Fix outdated problem (#249) * [Fix]: Fix outdated problem * [Fix]: Update MoCov3 bibtex * [Fix]: Use abs path in README * [Fix]: Reformat MAE bibtex * [Fix]: Reformat MoCov3 bibtex * [Feature] Resume from the latest checkpoint automatically. (#245) * [Feature] Resume from the latest checkpoint automatically. 
* fix windows path problem * fix lint * add code reference * [Docs] add docstring for ResNet and ResNeXt (#252) * [Feature] support KNN benchmark (#243) * [Feature] support KNN benchmark * [Fix] add docstring and multi-machine testing * [Fix] fix lint * [Fix] change args format and check init_cfg * [Docs] add benchmark tutorial * [Docs] add benchmark results * [Feature]: SimMIM supported (#239) * [Feature]: SimMIM Pretrain * [Feature]: Add mix precision and 16x128 config * [Fix]: Fix config import bug * [Fix]: Fix config bug * [Feature]: SimMIM Finetune * [Fix]: Log every 100 * [Fix]: Fix eval problem * [Feature]: Add docstring for simmim * [Refactor]: Merge layer wise lr decay to Default constructor * [Fix]: Fix simmim evaluation bug * [Fix]: Change model to be compatible to latest version of mmcls * [Fix]: Fix lint * [Fix]: Rewrite forward_train for classification cls * [Feature]: Add UT * [Fix]: Fix lint * [Feature]: Add 32 gpus training for simmim ft * [Fix]: Rename mmcls classifier wrapper * [Fix]: Add docstring to SimMIMNeck * [Feature]: Generate docstring for the forward function of simmim encoder * [Fix]: Rewrite the class docstring for constructor * [Fix]: Fix lint * [Fix]: Fix UT * [Fix]: Reformat config * [Fix]: Add img resolution * [Feature]: Add readme and metafile * [Fix]: Fix typo in README.md * [Fix]: Change BlackMaskGen to BlockwiseMaskGenerator * [Fix]: Change the name of SwinForSimMIM * [Fix]: Delete irrelevant files * [Feature]: Create extra transformerfinetuneconstructor * [Fix]: Fix lint * [Fix]: Update SimMIM README * [Fix]: Change SimMIMPretrainHead to SimMIMHead * [Fix]: Fix the docstring of ft constructor * [Fix]: Fix UT * [Fix]: Recover deletion Co-authored-by: Your <you@example.com> * [Fix] add seed to distributed sampler (#250) * [Fix] add seed to distributed sampler * fix lint * [Feature] Add ImageNet21k (#225) * solve memory leak by limited implementation * fix lint problem Co-authored-by: liming <liming.ai@bytedance.com> * [Refactor] 
change args format to '--a-b' (#253) * [Refactor] change args format to `--a-b` * modify tsne script * modify 'sh' files * modify getting_started.md * modify getting_started.md * [Fix] fix 'mkdir' error in prepare_voc07_cls.sh (#261) * [Fix] fix positional parameter error (#260) * [Fix] fix command errors in benchmarks tutorial (#263) * [Docs] add brief installation steps in README.md (#265) * [Docs] add colab tutorial (#247) * [Docs] add colab tutorial * fix lint * modify the colab tutorial, using API to train the model * modify the description * remove # * modify the command * [Docs] translate 6_benchmarks.md into Chinese (#262) * [Docs] translate 6_benchmarks.md into Chinese * Update 6_benchmarks.md change 基准 to 基准评测 * Update 6_benchmarks.md (1) Add Chinese translation of ‘1 folder for ImageNet nearest-neighbor classification task’ (2) 数据预准备 -> 数据准备 * [Docs] remove install scripts in README (#267) * [Docs] Update version information in dev branch (#268) * update version to v0.8.0 * fix lint * [Fix]: Install the latest mmcls * [Fix]: Add SimMIM in README Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com> Co-authored-by: Your <you@example.com> Co-authored-by: Ming Li <73068772+mitming@users.noreply.github.com> Co-authored-by: liming <liming.ai@bytedance.com> Co-authored-by: RenQin <45731309+soonera@users.noreply.github.com> Co-authored-by: YuanLiuuuuuu <3463423099@qq.com>
187 lines
7.1 KiB
Python
187 lines
7.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import os
|
|
import os.path as osp
|
|
import time
|
|
|
|
import mmcv
|
|
import torch
|
|
from mmcv import DictAction
|
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
|
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
|
|
|
from mmselfsup.datasets import build_dataloader, build_dataset
|
|
from mmselfsup.models import build_algorithm
|
|
from mmselfsup.models.utils import ExtractProcess, knn_classifier
|
|
from mmselfsup.utils import get_root_logger
|
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line string such as ``'False'`` into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True``
    for every non-empty string (including ``'False'``), so boolean options
    that take a value need a dedicated converter.

    Args:
        value (str | bool): Raw value from the command line.

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        f'Boolean value expected, got {value!r}')


def parse_args():
    """Parse command-line arguments for the k-NN benchmark.

    Besides parsing, this exports ``LOCAL_RANK`` to the environment when it
    is not already set, using the rank given on the command line.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='KNN evaluation')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--dataset-config',
        default='configs/benchmarks/classification/knn_imagenet.py',
        help='knn dataset config file path')
    parser.add_argument(
        '--work-dir', type=str, default=None, help='the dir to save results')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Older ``torch.distributed.launch`` passes ``--local_rank`` (underscore)
    # while newer versions pass ``--local-rank``; accept both spellings.
    parser.add_argument(
        '--local-rank',
        '--local_rank',
        dest='local_rank',
        type=int,
        default=0)
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # KNN settings
    parser.add_argument(
        '--num-knn',
        default=[10, 20, 100, 200],
        nargs='+',
        type=int,
        help='Number of NN to use. 20 usually works the best.')
    parser.add_argument(
        '--temperature',
        default=0.07,
        type=float,
        help='Temperature used in the voting coefficient.')
    # ``type=bool`` would turn ``--use-cuda False`` into True; parse the
    # string explicitly so the documented "Set to False" actually works.
    parser.add_argument(
        '--use-cuda',
        default=True,
        type=_str2bool,
        help='Store the features on GPU. Set to False if you encounter OOM')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
|
|
|
|
|
|
def _build_knn_dataloaders(dataset_cfg, distributed, logger):
    """Build the k-NN train/val datasets and their (unshuffled) loaders.

    The legacy ``imgs_per_gpu`` key is honoured on ``dataset_cfg.data``,
    because that is the config the loaders are actually built from.

    Args:
        dataset_cfg (mmcv.Config): Config loaded from ``--dataset-config``.
        distributed (bool): Whether to build distributed loaders.
        logger (logging.Logger): Logger for deprecation warnings.

    Returns:
        tuple: ``(dataset_train, dataset_val, data_loader_train,
        data_loader_val)``.
    """
    dataset_train = build_dataset(dataset_cfg.data.train)
    dataset_val = build_dataset(dataset_cfg.data.val)

    data_cfg = dataset_cfg.data
    if 'imgs_per_gpu' in data_cfg:
        logger.warning('"imgs_per_gpu" is deprecated. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in data_cfg:
            logger.warning(
                f'Got "imgs_per_gpu"={data_cfg.imgs_per_gpu} and '
                f'"samples_per_gpu"={data_cfg.samples_per_gpu}, "imgs_per_gpu"'
                f'={data_cfg.imgs_per_gpu} is used in this experiments')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{data_cfg.imgs_per_gpu} in this experiments')
        data_cfg.samples_per_gpu = data_cfg.imgs_per_gpu

    # Both splits share the same loader settings; order must be preserved
    # (shuffle=False) so features line up with ``get_gt_labels()``.
    loader_kwargs = dict(
        samples_per_gpu=data_cfg.samples_per_gpu,
        workers_per_gpu=data_cfg.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    data_loader_train = build_dataloader(dataset_train, **loader_kwargs)
    data_loader_val = build_dataloader(dataset_val, **loader_kwargs)
    return dataset_train, dataset_val, data_loader_train, data_loader_val


def _load_knn_weights(model, cfg, checkpoint, logger):
    """Select the weights used for feature extraction.

    Priority: a ``Pretrained`` backbone ``init_cfg`` > ``checkpoint`` >
    random init. The ``Pretrained`` branch only logs; its weights are
    expected to come from the caller's earlier ``model.init_weights()``.

    Args:
        model (nn.Module): The (unwrapped) algorithm model.
        cfg (mmcv.Config): Model config.
        checkpoint (str | None): Path given by ``--checkpoint``.
        logger (logging.Logger): Logger for the selection message.
    """
    init_cfg = cfg.model.backbone.get('init_cfg', None)
    # Previously any ``init_cfg`` (even a non-Pretrained one) suppressed the
    # checkpoint branch, silently extracting features from random weights.
    if isinstance(init_cfg, dict) and init_cfg.get('type') == 'Pretrained':
        logger.info(
            f'Use pretrained model: '
            f'{init_cfg.get("checkpoint")} to extract features')
    elif checkpoint is not None:
        logger.info(f'Use checkpoint: {checkpoint} to extract features')
        load_checkpoint(model, checkpoint, map_location='cpu')
    else:
        logger.info('No pretrained or checkpoint is given, use random init.')


def main():
    """Extract features with a self-supervised model and evaluate k-NN.

    Loads the model config (plus ``--cfg-options``), resolves the work dir,
    sets up optional distributed execution and logging, builds the k-NN
    dataloaders from ``--dataset-config``, initialises the model weights,
    extracts features for the train and val splits and, on rank 0 only,
    logs top-1/top-5 k-NN accuracy for every ``k`` in ``--num-knn``.
    """
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None;
        # the second path component (e.g. ``configs/<type>/...``) groups
        # runs. Guard paths without a directory, which raised IndexError.
        config_parts = args.config.split('/')
        work_type = config_parts[1] if len(config_parts) > 1 else ''
        cfg.work_dir = osp.join('./work_dirs', work_type,
                                osp.splitext(osp.basename(args.config))[0])

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.get('dist_params', {}))

    # create work_dir and init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    knn_work_dir = osp.join(cfg.work_dir, 'knn/')
    mmcv.mkdir_or_exist(osp.abspath(knn_work_dir))
    log_file = osp.join(knn_work_dir, f'knn_{timestamp}.log')
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.get('log_level', 'INFO'))

    # build the dataloaders
    dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
    (dataset_train, dataset_val, data_loader_train,
     data_loader_val) = _build_knn_dataloaders(dataset_cfg, distributed,
                                               logger)

    # build the model
    model = build_algorithm(cfg.model)
    model.init_weights()
    _load_knn_weights(model, cfg, args.checkpoint, logger)

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    model.eval()
    # build extraction processor and run
    extractor = ExtractProcess()
    train_feats = extractor.extract(
        model, data_loader_train, distributed=distributed)['feat']
    val_feats = extractor.extract(
        model, data_loader_val, distributed=distributed)['feat']

    train_feats = torch.from_numpy(train_feats)
    val_feats = torch.from_numpy(val_feats)
    train_labels = torch.LongTensor(dataset_train.data_source.get_gt_labels())
    val_labels = torch.LongTensor(dataset_val.data_source.get_gt_labels())

    logger.info('Features are extracted! Start k-NN classification...')

    # Only rank 0 runs the classifier and reports results.
    rank, _ = get_dist_info()
    if rank == 0:
        if args.use_cuda:
            train_feats = train_feats.cuda()
            val_feats = val_feats.cuda()
            train_labels = train_labels.cuda()
            val_labels = val_labels.cuda()
        for k in args.num_knn:
            top1, top5 = knn_classifier(train_feats, train_labels, val_feats,
                                        val_labels, k, args.temperature)
            logger.info(
                f'{k}-NN classifier result: Top1: {top1}, Top5: {top5}')
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the k-NN benchmark only when executed as a script, not on import.
    main()
|