mmselfsup/tools/analysis_tools/visualize_tsne.py
Yixiao Fang df907e5ce0
Bump version to v0.8.0 (#269)
* [Fix]: Fix mmcls upgrade bug (#235)

* [Feature]: Add multi machine dist_train (#232)

* [Feature]: Add multi machine dist_train

* [Fix]: Change bash to sh

* [Fix]: Fix missing sh suffix

* [Refactor]: Change bash to sh

* [Refactor] Add unit test (#234)

* [Refactor] add unit test

* update workflow

* update

* [Fix] fix lint

* update test

* refactor moco and densecl unit test

* fix lint

* add unit test

* update unit test

* remove modification

* [Feature]: Add MAE metafile (#238)

* [Feature]: Add MAE metafile

* [Fix]: Fix lint

* [Fix]: Change LARS to AdamW in the metafile of MAE

* [Fix] fix codecov bug (#241)

* [Fix] fix codecov bug

* update comment

* [Refactor] Using MMCls backbones (#233)

* [Refactor] using backbones from MMCls

* [Refactor] modify the unit test

* [Fix] modify default setting of out_indices

* [Docs] fix lint

* [Refactor] modify super init

* [Refactore] remove res_layer.py

* using mmcv PatchEmbed

* [Fix]: Fix outdated problem (#249)

* [Fix]: Fix outdated problem

* [Fix]: Update MoCov3 bibtex

* [Fix]: Use abs path in README

* [Fix]: Reformat MAE bibtex

* [Fix]: Reformat MoCov3 bibtex

* [Feature] Resume from the latest checkpoint automatically. (#245)

* [Feature] Resume from the latest checkpoint automatically.

* fix windows path problem

* fix lint

* add code reference

* [Docs] add docstring for ResNet and ResNeXt (#252)

* [Feature] support KNN benchmark (#243)

* [Feature] support KNN benchmark

* [Fix] add docstring and multi-machine testing

* [Fix] fix lint

* [Fix] change args format and check init_cfg

* [Docs] add benchmark tutorial

* [Docs] add benchmark results

* [Feature]: SimMIM supported (#239)

* [Feature]: SimMIM Pretrain

* [Feature]: Add mix precision and 16x128 config

* [Fix]: Fix config import bug

* [Fix]: Fix config bug

* [Feature]: Simim Finetune

* [Fix]: Log every 100

* [Fix]: Fix eval problem

* [Feature]: Add docstring for simmim

* [Refactor]: Merge layer wise lr decay to Default constructor

* [Fix]:Fix simmim evaluation bug

* [Fix]: Change model to be compatible to latest version of mmcls

* [Fix]: Fix lint

* [Fix]: Rewrite forward_train for classification cls

* [Feature]: Add UT

* [Fix]: Fix lint

* [Feature]: Add 32 gpus training for simmim ft

* [Fix]: Rename mmcls classifier wrapper

* [Fix]: Add docstring to SimMIMNeck

* [Feature]: Generate docstring for the forward function of simmim encoder

* [Fix]: Rewrite the class docstring for constructor

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Reformat config

* [Fix]: Add img resolution

* [Feature]: Add readme and metafile

* [Fix]: Fix typo in README.md

* [Fix]: Change BlackMaskGen to BlockwiseMaskGenerator

* [Fix]: Change the name of SwinForSimMIM

* [Fix]: Delete irrelevant files

* [Feature]: Create extra transformerfinetuneconstructor

* [Fix]: Fix lint

* [Fix]: Update SimMIM README

* [Fix]: Change SimMIMPretrainHead to SimMIMHead

* [Fix]: Fix the docstring of ft constructor

* [Fix]: Fix UT

* [Fix]: Recover deletion

Co-authored-by: Your <you@example.com>

* [Fix] add seed to distributed sampler (#250)

* [Fix] add seed to distributed sampler

* fix lint

* [Feature] Add ImageNet21k (#225)

* solve memory leak by limited implementation

* fix lint problem

Co-authored-by: liming <liming.ai@bytedance.com>

* [Refactor] change args format to '--a-b' (#253)

* [Refactor] change args format to `--a-b`

* modify tsne script

* modify 'sh' files

* modify getting_started.md

* modify getting_started.md

* [Fix] fix 'mkdir' error in prepare_voc07_cls.sh (#261)

* [Fix] fix positional parameter error (#260)

* [Fix] fix command errors in benchmarks tutorial (#263)

* [Docs] add brief installation steps in README.md (#265)

* [Docs] add colab tutorial (#247)

* [Docs] add colab tutorial

* fix lint

* modify the colab tutorial, using API to train the model

* modify the description

* remove #

* modify the command

* [Docs] translate 6_benchmarks.md into Chinese (#262)

* [Docs] translate 6_benchmarks.md into Chinese

* Update 6_benchmarks.md

change 基准 (benchmark) to 基准评测 (benchmark evaluation)

* Update 6_benchmarks.md

(1) Add Chinese translation of ‘1 folder for ImageNet nearest-neighbor classification task’
(2) 数据预准备 (data pre-preparation) -> 数据准备 (data preparation)

* [Docs] remove install scripts in README (#267)

* [Docs] Update version information in dev branch (#268)

* update version to v0.8.0

* fix lint

* [Fix]: Install the latest mmcls

* [Fix]: Add SimMIM in RAEDME

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
Co-authored-by: Your <you@example.com>
Co-authored-by: Ming Li <73068772+mitming@users.noreply.github.com>
Co-authored-by: liming <liming.ai@bytedance.com>
Co-authored-by: RenQin <45731309+soonera@users.noreply.github.com>
Co-authored-by: YuanLiuuuuuu <3463423099@qq.com>
2022-03-31 18:47:54 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import time

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from sklearn.manifold import TSNE

from mmselfsup.apis import set_random_seed
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.models import build_algorithm
from mmselfsup.models.utils import MultiExtractProcess
from mmselfsup.utils import get_root_logger


def parse_args():
    parser = argparse.ArgumentParser(description='t-SNE visualization')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--work_dir',
        help='(Deprecated, please use --work-dir) the dir to save logs and '
        'models')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--dataset_config',
        default='configs/benchmarks/classification/tsne_imagenet.py',
        help='(Deprecated, please use --dataset-config) config file of the '
        'dataset used for feature extraction')
    parser.add_argument(
        '--dataset-config',
        default='configs/benchmarks/classification/tsne_imagenet.py',
        help='config file of the dataset used for feature extraction')
    parser.add_argument(
        '--layer_ind',
        type=str,
        default='0,1,2,3,4',
        help='(Deprecated, please use --layer-ind) layer indices, '
        'separated by comma, e.g., "0,1,2,3,4"')
    parser.add_argument(
        '--layer-ind',
        type=str,
        default='0,1,2,3,4',
        help='layer indices, separated by comma, e.g., "0,1,2,3,4"')
    parser.add_argument(
        '--pool_type',
        choices=['specified', 'adaptive'],
        default='specified',
        help='(Deprecated, please use --pool-type) Pooling type in '
        ':class:`MultiPooling`')
    parser.add_argument(
        '--pool-type',
        choices=['specified', 'adaptive'],
        default='specified',
        help='Pooling type in :class:`MultiPooling`')
    parser.add_argument(
        '--max_num_class',
        type=int,
        default=20,
        help='(Deprecated, please use --max-num-class) the maximum number '
        'of classes to apply t-SNE to; at most 20 classes are currently '
        'supported')
    parser.add_argument(
        '--max-num-class',
        type=int,
        default=20,
        help='the maximum number of classes to apply t-SNE to; at most 20 '
        'classes are currently supported')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # t-SNE settings
    parser.add_argument(
        '--n_components',
        type=int,
        default=2,
        help='(Deprecated, please use --n-components) the dimension of the '
        'embedded space')
    parser.add_argument(
        '--n-components',
        type=int,
        default=2,
        help='the dimension of the embedded space')
    parser.add_argument(
        '--perplexity',
        type=float,
        default=30.0,
        help='The perplexity is related to the number of nearest neighbors '
        'that is used in other manifold learning algorithms.')
    parser.add_argument(
        '--early_exaggeration',
        type=float,
        default=12.0,
        help='(Deprecated, please use --early-exaggeration) Controls how '
        'tight natural clusters in the original space are in the embedded '
        'space and how much space will be between them.')
    parser.add_argument(
        '--early-exaggeration',
        type=float,
        default=12.0,
        help='Controls how tight natural clusters in the original space are '
        'in the embedded space and how much space will be between them.')
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=200.0,
        help='(Deprecated, please use --learning-rate) The learning rate '
        'for t-SNE is usually in the range [10.0, 1000.0]. '
        'If the learning rate is too high, the data may look '
        'like a ball with any point approximately equidistant from its '
        'nearest neighbours. If the learning rate is too low, most points '
        'may look compressed in a dense cloud with few outliers.')
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=200.0,
        help='The learning rate for t-SNE is usually in the range '
        '[10.0, 1000.0]. If the learning rate is too high, the data may look '
        'like a ball with any point approximately equidistant from its '
        'nearest neighbours. If the learning rate is too low, most points '
        'may look compressed in a dense cloud with few outliers.')
    parser.add_argument(
        '--n_iter',
        type=int,
        default=1000,
        help='(Deprecated, please use --n-iter) Maximum number of iterations '
        'for the optimization. Should be at least 250.')
    parser.add_argument(
        '--n-iter',
        type=int,
        default=1000,
        help='Maximum number of iterations for the optimization. Should be at '
        'least 250.')
    parser.add_argument(
        '--n_iter_without_progress',
        type=int,
        default=300,
        help='(Deprecated, please use --n-iter-without-progress) Maximum '
        'number of iterations without progress before we abort the '
        'optimization.')
    parser.add_argument(
        '--n-iter-without-progress',
        type=int,
        default=300,
        help='Maximum number of iterations without progress before we abort '
        'the optimization.')
    parser.add_argument(
        '--init',
        type=str,
        default='random',
        help='initialization of the embedding, e.g. "random" or "pca"')
    args = parser.parse_args()
    return args
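

# Illustrative sketch only (not called by main()): how the deprecated
# underscore flags in parse_args() keep working. argparse derives an
# option's `dest` from its first long option string, replacing '-' with
# '_', so '--layer_ind' and '--layer-ind' both store into
# `args.layer_ind`. The two `_sketch_*` helpers are hypothetical names
# used only for this example. A single call such as
# add_argument('--layer-ind', '--layer_ind', ...) would register both
# spellings as one action instead of two.
def _sketch_alias_parser():
    """Build a parser that accepts both spellings of --layer-ind."""
    import argparse  # local import keeps the sketch self-contained

    parser = argparse.ArgumentParser()
    parser.add_argument('--layer_ind', type=str, default='0,1,2,3,4')
    parser.add_argument('--layer-ind', type=str, default='0,1,2,3,4')
    return parser


def _sketch_parse_layer_ind(layer_ind):
    """Split a comma-separated string such as '0,1,2' into [0, 1, 2]."""
    return [int(idx) for idx in layer_ind.split(',')]


# Example: _sketch_parse_layer_ind(
#     _sketch_alias_parser().parse_args(['--layer_ind', '2,4']).layer_ind)
# gives [2, 4], the same list main() would assign to `out_indices`.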
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # work_dir is determined in this priority:
    # CLI > work_dir set in the config file > config filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        work_type = args.config.split('/')[1]
        cfg.work_dir = osp.join('./work_dirs', work_type,
                                osp.splitext(osp.basename(args.config))[0])

    # get out_indices from args
    layer_ind = [int(idx) for idx in args.layer_ind.split(',')]
    cfg.model.backbone.out_indices = layer_ind

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir and init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    tsne_work_dir = osp.join(cfg.work_dir, f'tsne_{timestamp}/')
    mmcv.mkdir_or_exist(osp.abspath(tsne_work_dir))
    log_file = osp.join(tsne_work_dir, 'extract.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    # build the dataloader
    dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
    dataset = build_dataset(dataset_cfg.data.extract)
    # shrink the dataset: keep only the samples whose label is less than
    # max_num_class
    tmp_infos = []
    for i in range(len(dataset)):
        if dataset.data_source.data_infos[i]['gt_label'] < args.max_num_class:
            tmp_infos.append(dataset.data_source.data_infos[i])
    dataset.data_source.data_infos = tmp_infos
    logger.info(f'Apply t-SNE to visualize {len(dataset)} samples.')

    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiment')
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiment')
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
        workers_per_gpu=dataset_cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)
    # build the model
    model = build_algorithm(cfg.model)
    model.init_weights()

    # model is determined in this priority: init_cfg > checkpoint > random
    init_cfg = cfg.model.backbone.get('init_cfg', None)
    if isinstance(init_cfg, dict) and init_cfg.get('type') == 'Pretrained':
        logger.info(f'Use pretrained model: {init_cfg["checkpoint"]} to '
                    'extract features')
    elif args.checkpoint is not None:
        logger.info(f'Use checkpoint: {args.checkpoint} to extract features')
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    else:
        logger.info('No pretrained or checkpoint is given, use random init.')
    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    # build extraction processor and run
    extractor = MultiExtractProcess(
        pool_type=args.pool_type, backbone='resnet50', layer_indices=layer_ind)
    features = extractor.extract(model, data_loader, distributed=distributed)
    labels = dataset.data_source.get_gt_labels()

    # save features; get_dist_info() reports rank 0 when not distributed,
    # so exactly one process writes the files in either mode
    mmcv.mkdir_or_exist(f'{tsne_work_dir}features/')
    logger.info(f'Save features to {tsne_work_dir}features/')
    rank, _ = get_dist_info()
    if rank == 0:
        for key, val in features.items():
            output_file = \
                f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy'
            np.save(output_file, val)
    # build t-SNE model
    # note: newer scikit-learn releases (1.5 and later) rename `n_iter` to
    # `max_iter`; adjust the keyword below if your version rejects `n_iter`
    tsne_model = TSNE(
        n_components=args.n_components,
        perplexity=args.perplexity,
        early_exaggeration=args.early_exaggeration,
        learning_rate=args.learning_rate,
        n_iter=args.n_iter,
        n_iter_without_progress=args.n_iter_without_progress,
        init=args.init)

    # run and get results
    mmcv.mkdir_or_exist(f'{tsne_work_dir}saved_pictures/')
    logger.info('Running t-SNE......')
    for key, val in features.items():
        result = tsne_model.fit_transform(val)
        # min-max normalize each embedding axis to [0, 1] for plotting
        res_min, res_max = result.min(0), result.max(0)
        res_norm = (result - res_min) / (res_max - res_min)
        plt.figure(figsize=(10, 10))
        plt.scatter(
            res_norm[:, 0],
            res_norm[:, 1],
            alpha=1.0,
            s=15,
            c=labels,
            cmap='tab20')
        plt.savefig(f'{tsne_work_dir}saved_pictures/{key}.png')
        # release the figure so memory does not grow with each layer
        plt.close()
    logger.info(f'Saved results to {tsne_work_dir}saved_pictures/')
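

# Illustrative sketch only (not called by main()): the two array steps
# used above, written against plain Python lists and NumPy arrays instead
# of mmselfsup dataset objects. `_sketch_filter_by_label` mirrors the
# loop that keeps samples whose 'gt_label' is below `max_num_class`, and
# `_sketch_minmax_normalize` mirrors the per-axis scaling of the t-SNE
# output to [0, 1]. Both names are hypothetical. Unlike main(), the
# normalization sketch guards against an axis whose values are all equal.
def _sketch_filter_by_label(infos, max_num_class):
    """Keep only the info dicts whose 'gt_label' < max_num_class."""
    return [info for info in infos if info['gt_label'] < max_num_class]


def _sketch_minmax_normalize(points):
    """Scale every column of an (N, D) array into the range [0, 1]."""
    import numpy as np  # local import keeps the sketch self-contained

    points = np.asarray(points, dtype=float)
    col_min, col_max = points.min(axis=0), points.max(axis=0)
    # use a span of 1.0 for constant columns to avoid dividing by zero
    span = np.where(col_max > col_min, col_max - col_min, 1.0)
    return (points - col_min) / span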


if __name__ == '__main__':
    main()