Yuan Liu 399b5a0d6e
Bump version to v0.9.0 (#299)
* [Feature]: MAE pre-training with fp16 (#271)

* [Feature]: MAE pre-training with fp16

* [Fix]: Fix lint

* [Fix]: Fix SimMIM config link, and add SimMIM to model_zoo (#272)

* [Fix]: Fix link error

* [Fix]: Add SimMIM to model zoo

* [Fix]: Fix lint

* [Fix] fix 'no init_cfg' error for pre-trained model backbones (#256)

* [UT] add unit test for apis (#276)

* [UT] add unit test for apis

* ignore pytest log

* [Feature] Add extra dataloader settings in configs. (#264)

* [Feature] support to set validation samples per gpu independently

* set default to be cfg.data.samples_per_gpu

* modify the tools/test.py

* using 'train_dataloader', 'val_dataloader', 'test_dataloader' for specific settings

* test 'evaluation' branch

* [Fix]: Change imgs_per_gpu to samples_per_gpu MAE (#278)

* [Feature]: Add SimMIM 192 pt 224 ft (#280)

* [Feature]: Add SimMIM 192 pt 224 ft

* [Feature]: Add simmim 192 pt 224 ft to readme

* [Fix] fix key error bug when registering custom hooks (#273)

* [UT] remove pytorch1.5 test (#288)

* [Benchmark] rename linear probing config file names (#281)

* [Benchmark] rename linear probing config file names

* update config links

* Avoid GPU memory leak with prefetch dataloader (#277)

* [Feature] barlowtwins (#207)

* [Fix]: Fix mmcls upgrade bug (#235)

* [Feature]: Add multi machine dist_train (#232)

* [Feature]: Add multi machine dist_train

* [Fix]: Change bash to sh

* [Fix]: Fix missing sh suffix

* [Refactor]: Change bash to sh

* [Refactor] Add unit test (#234)

* [Refactor] add unit test

* update workflow

* update

* [Fix] fix lint

* update test

* refactor moco and densecl unit test

* fix lint

* add unit test

* update unit test

* remove modification

* [Feature]: Add MAE metafile (#238)

* [Feature]: Add MAE metafile

* [Fix]: Fix lint

* [Fix]: Change LARS to AdamW in the metafile of MAE

* Add barlowtwins

* Add unit test for barlowtwins

* Adjust training params

* add decorator to pass CI

* adjust params

* Add barlowtwins

* Add unit test for barlowtwins

* Adjust training params

* add decorator to pass CI

* adjust params

* add barlowtwins configs

* revise LatentCrossCorrelationHead

* modify ut to save memory

* add metafile

* add barlowtwins results to model zoo

* add barlow twins to homepage

* fix batch size bug

* add algorithm readme

* add type hints

* reorganize the model zoo

* remove one config

* recover the config

* add missing docstring

* revise barlowtwins

* reorganize coco and voc benchmark

* add barlowtwins to index.rst

* revise docstring

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>

* [Fix] fix --local-rank (#290)

* [UT] reduce memory usage while running unit test (#291)

* [Feature]: CAE Supported (#284)

* [Feature]: Add mc

* [Feature]: Add dataset of CAE

* [Feature]: Init version of CAE

* [Feature]: Add mc

* [Fix]: Change beta to (0.9, 0.999)

* [Fix]: New feature

* [Fix]: Decouple the qkv bias

* [Feature]: Decouple qkv bias in MultiheadAttention

* [Feature]: New mask generator

* [Fix]: Fix TransformEncoderLayer bug

* [Feature]: Add MAE CAE linear prob

* [Fix]: Fix config

* [Fix]: Delete redundant mc

* [Fix]: Add init value in mim cls vit

* [Fix]: Fix cae ft config

* [Fix]: Delete repeated init_values

* [Fix]: Change bs from 64 to 128 in CAE ft

* [Fix]: Add mc in cae pt

* [Fix]: Fix momentum update bug

* [Fix]: Add no weight_decay for gamma

* [Feature]: Add mc for cae pt

* [Fix]: Delete mc

* [Fix]: Delete redundant files

* [Fix]: Fix lint

* [Feature]: Add docstring to algo, backbone, neck and head

* [Fix]: Fix lint

* [Fix]: network

* [Feature]: Add docstrings for network blocks

* [Feature]: Add docstring to ToTensor

* [Feature]: Add docstring to transform

* [Fix]: Add type hint to BEiTMaskGenerator

* [Fix]: Fix lint

* [Fix]: Add copyright to dalle_e

* [Fix]: Fix BlockwiseMaskGenerator

* [Feature]: Add UT for CAE

* [Fix]: Fix dalle state_dict path not existed bug

* [Fix]: Delete file_client_args related code

* [Fix]: Remove redundant code

* [Refactor]: Add fp16 to the name of cae pre-train config

* [Refactor]: Use FFN from mmcv

* [Refactor]: Change network_blocks to transformer_blocks

* [Fix]: Fix mask generator name bug

* [Fix]: cae pre-train config bug

* [Fix]: Fix docstring grammar

* [Fix]: Fix mc related code

* [Fix]: Add object parent to transform

* [Fix]: Delete unnecessary modification

* [Fix]: Change blockwisemask generator to simmim mask generator

* [Refactor]: Change cae mae pretrain vit to cae mae vit

* [Refactor]: Change lamb to lambd

* [Fix]: Remove blank line

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Delete modification to swin

* [Fix]: Fix lint

* [Feature]: Add README and metafile

* [Feature]: Update index.rst

* [Fix]: Update model_zoo

* [Fix]: Change MAE to CAE in algorithm

* [Fix]: Change SimMIMMaskGenerator to CAEMaskGenerator

* [Fix]: Fix model zoo

* [Fix]: Change to dalle_encoder

* [Feature]: Add download link for dalle

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Update metafile

* [Fix]: Change b to base

* [Feature]: Add dalle download link in warning

* [Fix] add arxiv link in readme

Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>

* [Enhance] update SimCLR models and results (#295)

* [Enhance] update simclr models and results

* [Fix] revise comments to indicate settings

* Update version (#296)

* [Feature]: Update to 0.9.0

* [Feature]: Add version constraint for mmcls

* [Fix]: Fix bug

* [Fix]: Fix version bug

* [Feature]: Update version in install.md

* update changelog

* update readme

* [Fix] fix uppercase

* [Fix] fix uppercase

* [Fix] fix uppercase

* update version dependency

* add cae to readme

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Co-authored-by: Ming Li <73068772+mitming@users.noreply.github.com>
Co-authored-by: xcnick <xcnick0412@gmail.com>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
2022-04-29 20:01:30 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import time

import mmcv
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.models import build_algorithm
from mmselfsup.models.utils import ExtractProcess, knn_classifier
from mmselfsup.utils import get_root_logger


def str2bool(value):
    """Parse a command-line string into a bool.

    ``type=bool`` would treat any non-empty string, including 'False', as
    True, so the flag is parsed explicitly instead.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', 't', 'yes', 'y', '1'):
        return True
    if value.lower() in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def parse_args():
    parser = argparse.ArgumentParser(description='KNN evaluation')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--dataset-config',
        default='configs/benchmarks/classification/knn_imagenet.py',
        help='knn dataset config file path')
    parser.add_argument(
        '--work-dir', type=str, default=None, help='the dir to save results')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # KNN settings
    parser.add_argument(
        '--num-knn',
        default=[10, 20, 100, 200],
        nargs='+',
        type=int,
        help='Number of NN to use. 20 usually works the best.')
    parser.add_argument(
        '--temperature',
        default=0.07,
        type=float,
        help='Temperature used in the voting coefficient.')
    parser.add_argument(
        '--use-cuda',
        default=True,
        type=str2bool,
        help='Store the features on GPU. Set to False if you encounter OOM')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        work_type = args.config.split('/')[1]
        cfg.work_dir = osp.join('./work_dirs', work_type,
                                osp.splitext(osp.basename(args.config))[0])

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir and init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    knn_work_dir = osp.join(cfg.work_dir, 'knn/')
    mmcv.mkdir_or_exist(osp.abspath(knn_work_dir))
    log_file = osp.join(knn_work_dir, f'knn_{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # build the dataloader
    dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
    dataset_train = build_dataset(dataset_cfg.data.train)
    dataset_val = build_dataset(dataset_cfg.data.val)
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiment')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiment')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
    data_loader_train = build_dataloader(
        dataset_train,
        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
        workers_per_gpu=dataset_cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    data_loader_val = build_dataloader(
        dataset_val,
        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
        workers_per_gpu=dataset_cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model
    model = build_algorithm(cfg.model)
    model.init_weights()

    # model is determined in this priority: init_cfg > checkpoint > random
    if hasattr(cfg.model.backbone, 'init_cfg'):
        if getattr(cfg.model.backbone.init_cfg, 'type', None) == 'Pretrained':
            logger.info(
                f'Use pretrained model: '
                f'{cfg.model.backbone.init_cfg.checkpoint} to extract features'
            )
    elif args.checkpoint is not None:
        logger.info(f'Use checkpoint: {args.checkpoint} to extract features')
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    else:
        logger.info('No pretrained or checkpoint is given, use random init.')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    model.eval()
    # build extraction processor and run
    extractor = ExtractProcess()
    train_feats = extractor.extract(
        model, data_loader_train, distributed=distributed)['feat']
    val_feats = extractor.extract(
        model, data_loader_val, distributed=distributed)['feat']

    train_feats = torch.from_numpy(train_feats)
    val_feats = torch.from_numpy(val_feats)
    train_labels = torch.LongTensor(dataset_train.data_source.get_gt_labels())
    val_labels = torch.LongTensor(dataset_val.data_source.get_gt_labels())

    logger.info('Features are extracted! Start k-NN classification...')

    rank, _ = get_dist_info()
    if rank == 0:
        if args.use_cuda:
            train_feats = train_feats.cuda()
            val_feats = val_feats.cuda()
            train_labels = train_labels.cuda()
            val_labels = val_labels.cuda()
        for k in args.num_knn:
            top1, top5 = knn_classifier(train_feats, train_labels, val_feats,
                                        val_labels, k, args.temperature)
            logger.info(
                f'{k}-NN classifier result: Top1: {top1}, Top5: {top5}')


if __name__ == '__main__':
    main()
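
The `--num-knn` and `--temperature` options above control a weighted k-NN vote, but the script itself never shows what that vote looks like. The sketch below is a simplified, single-query illustration in plain Python. It assumes L2-normalised features and a voting weight of `exp(similarity / temperature)`, a common choice for this kind of evaluation. It is not the implementation of `knn_classifier` in `mmselfsup`, and the name `weighted_knn_predict` is hypothetical.

```python
import math
from collections import defaultdict


def weighted_knn_predict(train_feats, train_labels, query, k, temperature):
    """Predict a label for `query` by temperature-weighted k-NN voting.

    Illustrative helper (not part of mmselfsup). Feature vectors are assumed
    to be L2-normalised, so the dot product equals cosine similarity.
    """
    # Similarity between the query and every training feature.
    sims = [sum(q * t for q, t in zip(query, feat)) for feat in train_feats]
    # Indices of the k most similar training samples.
    top_k = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    # Each neighbour votes for its label with weight exp(similarity / T).
    votes = defaultdict(float)
    for i in top_k:
        votes[train_labels[i]] += math.exp(sims[i] / temperature)
    # The label with the largest accumulated weight wins.
    return max(votes, key=votes.get)


# Toy data: one sample of class 0 identical to the query, two of class 1.
train_feats = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
train_labels = [0, 1, 1]
query = [1.0, 0.0]

# A low temperature lets the closest neighbour dominate the vote.
pred_sharp = weighted_knn_predict(
    train_feats, train_labels, query, k=3, temperature=0.07)
# A high temperature flattens the weights towards a plain majority vote.
pred_flat = weighted_knn_predict(
    train_feats, train_labels, query, k=3, temperature=10.0)
print(pred_sharp, pred_flat)
```

With the toy data, the low temperature picks the class of the single closest neighbour, while the high temperature lets the two more distant class-1 samples outvote it. This is why `k` and the temperature are usually reported together.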