[Refactor] Refactor t-SNE (#629)

* refactor tsne

* update configs

* update docs

* add cls_token option
Yixiao Fang 2022-12-21 16:21:12 +08:00 committed by GitHub
parent fd6659bba5
commit 347b49e77f
6 changed files with 207 additions and 107 deletions

View File

@@ -1,26 +0,0 @@
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
name = 'imagenet_val'
extract_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackSelfSupInputs'),
]
extract_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=extract_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# pooling cfg
pool_cfg = dict(type='MultiPooling', in_indices=(1, 2, 3, 4))
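
For reference, `MultiPooling` pooled the feature map of each stage named by `in_indices` into a fixed-length vector before t-SNE; the refactored script below replaces this with inline `F.adaptive_avg_pool2d` post-processing. A rough sketch of the idea, not the actual `MultiPooling` implementation:

```python
import torch
import torch.nn.functional as F

def pool_stages(stage_feats):
    """Pool one (N, C, H, W) map per entry of in_indices to an (N, C) vector."""
    return [F.adaptive_avg_pool2d(f, 1).flatten(1) for f in stage_feats]

# ResNet-50 stage shapes for a 224x224 input (for illustration only)
feats = [torch.randn(8, c, s, s)
         for c, s in [(256, 56), (512, 28), (1024, 14), (2048, 7)]]
vectors = pool_stages(feats)  # shapes: (8, 256), (8, 512), (8, 1024), (8, 2048)
```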

View File

@@ -0,0 +1,48 @@
_base_ = 'mmcls::_base_/default_runtime.py'
model = dict(
_scope_='mmcls',
type='ImageClassifier',
data_preprocessor=dict(
num_classes=1000,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
),
backbone=dict(
type='ResNet',
depth=50,
in_channels=3,
num_stages=4,
out_indices=(3, ),
norm_cfg=dict(type='BN'),
frozen_stages=-1),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
extract_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs'),
]
extract_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=extract_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
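
This per-model config now carries everything the t-SNE tool needs. If you want to build the `extract_dataloader` yourself, here is a minimal sketch mirroring what the refactored `visualize_tsne.py` does; it assumes mmselfsup 1.x is installed and ImageNet is available under `data/imagenet`:

```python
from mmengine.config import Config
from mmengine.dataset import default_collate
from torch.utils.data import DataLoader

from mmselfsup.registry import DATASETS
from mmselfsup.utils import register_all_modules

register_all_modules()  # register mmselfsup/mmcls datasets and transforms
cfg = Config.fromfile('configs/tsne/resnet50_imagenet.py')

loader_cfg = cfg.extract_dataloader.copy()
dataset = DATASETS.build(loader_cfg.pop('dataset'))
loader_cfg.pop('sampler')  # the script builds a DefaultSampler via DATA_SAMPLERS
loader = DataLoader(
    dataset, collate_fn=default_collate, shuffle=False, **loader_cfg)
```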

View File

@@ -0,0 +1,42 @@
_base_ = 'mmcls::_base_/default_runtime.py'
model = dict(
_scope_='mmcls',
type='ImageClassifier',
backbone=dict(
type='SwinTransformer',
arch='base',
img_size=192,
out_indices=-1,
drop_path_rate=0.1,
stage_cfgs=dict(block_cfgs=dict(window_size=6))),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False))
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
extract_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs'),
]
extract_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=extract_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
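
Why `in_channels=1024` here: Swin-B starts from an embedding width of 128 and doubles it in each of its four stages, so the last stage, selected by `out_indices=-1`, outputs 128 * 2^3 = 1024 channels. A quick sanity check (the stem width is an assumption based on the standard Swin-B layout, not read from this config):

```python
embed_dims = 128  # Swin-B stem width (assumed)
num_stages = 4    # channel width doubles after each of the first three stages
assert embed_dims * 2 ** (num_stages - 1) == 1024  # matches head.in_channels
```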

View File

@@ -0,0 +1,45 @@
_base_ = 'mmcls::_base_/default_runtime.py'
model = dict(
_scope_='mmcls',
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
out_indices=-1,
drop_path_rate=0.1,
avg_token=False,
output_cls_token=False,
final_norm=False),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
extract_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='mmcls.PackClsInputs'),
]
extract_dataloader = dict(
batch_size=8,
num_workers=4,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=extract_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
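
This config sets `avg_token=False` and `output_cls_token=False`, which is what the new `cls_token` option in the script below keys on: when `output_cls_token` is `False` the patch-token map is average-pooled, otherwise t-SNE uses the cls_token directly. A sketch of that branch, mirroring the post-processing in `visualize_tsne.py` (not a verbatim copy):

```python
import torch.nn.functional as F

def post_process_backbone_feats(batch_features, output_cls_token: bool):
    if not output_cls_token:
        # pool each (N, C, H, W) feature map down to an (N, C) vector
        return [F.adaptive_avg_pool2d(f, 1).squeeze() for f in batch_features]
    # with output_cls_token=True each feature ends with the cls_token
    return [feat[-1] for feat in batch_features]
```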

View File

@@ -123,21 +123,26 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT
 Arguments:
-- `CONFIG_FILE`: config file for the pre-trained model.
-- `CKPT_PATH`: the path of model's checkpoint.
+- `CONFIG_FILE`: config file for t-SNE, which is listed in the directory `configs/tsne/`
+- `CKPT_PATH`: the path or link of the model's checkpoint.
 - `WORK_DIR`: the directory to save the results of visualization.
-- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
+- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_tsne.py)
-An example:
+An example command:
 ```shell
-python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
+python ./tools/analysis_tools/visualize_tsne.py \
+    configs/tsne/resnet50_imagenet.py \
+    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k/mocov2_resnet50_8xb32-coslr-200e_in1k_20220825-b6d23c86.pth \
+    --work-dir ./work_dirs/tsne/mocov2/ \
+    --max-num-class 100
 ```
-An example of visualization:
+An example of visualization, where the left figure is from `MoCoV2_ResNet50` and the right is from `MAE_ViT-base`:
 <div align="center">
-<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
+<img src="https://user-images.githubusercontent.com/36138628/207305086-91df298c-0eb7-4254-9c5b-ba711644501b.png" width="250" />
+<img src="https://user-images.githubusercontent.com/36138628/207305333-59af4747-1e9c-4f85-a57d-c7e5d132a6e5.png" width="250" />
 </div>
 ## Visualize Low-level Feature Reconstruction

View File

@@ -6,37 +6,33 @@ from functools import partial
 from typing import Optional
 import matplotlib.pyplot as plt
 import mmengine
 import numpy as np
 import torch
+import torch.nn.functional as F
 from mmengine.config import Config, DictAction
 from mmengine.dataset import default_collate, worker_init_fn
-from mmengine.dist import get_rank, init_dist
+from mmengine.dist import get_rank
 from mmengine.logging import MMLogger
-from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
-from mmengine.runner import load_checkpoint
 from mmengine.utils import mkdir_or_exist
 from sklearn.manifold import TSNE
 from torch.utils.data import DataLoader
-from mmselfsup.models.utils import Extractor
-from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
+from mmselfsup.apis import init_model
+from mmselfsup.registry import DATA_SAMPLERS, DATASETS
 from mmselfsup.utils import register_all_modules
 def parse_args():
     parser = argparse.ArgumentParser(description='t-SNE visualization')
-    parser.add_argument('config', help='train config file path')
+    parser.add_argument('config', help='tsne config file path')
     parser.add_argument('--checkpoint', default=None, help='checkpoint file')
     parser.add_argument('--work-dir', help='the dir to save logs and models')
-    parser.add_argument(
-        '--launcher',
-        choices=['none', 'pytorch', 'slurm', 'mpi'],
-        default='none',
-        help='job launcher')
     parser.add_argument(
-        '--dataset-config',
-        default='configs/benchmarks/classification/tsne_imagenet.py',
-        help='extract dataset config file path')
+        '--vis-stage',
+        choices=['backbone', 'neck', 'pre_logits'],
+        default='backbone',
+        help='the visualization stage of the model')
     parser.add_argument(
         '--max-num-class',
         type=int,
@@ -58,6 +54,8 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
     # t-SNE settings
     parser.add_argument(
@@ -101,6 +99,10 @@
     return args
+def post_process():
+    pass
 def main():
     args = parse_args()
@@ -123,13 +125,6 @@
     cfg.work_dir = osp.join('./work_dirs', work_type,
                             osp.splitext(osp.basename(args.config))[0])
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher)
     # create work_dir
     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
     tsne_work_dir = osp.join(cfg.work_dir, f'tsne_{timestamp}/')
@@ -143,9 +138,13 @@
         log_file=log_file,
         log_level=cfg.log_level)
+    # build the model from a config file and a checkpoint file
+    model = init_model(cfg, args.checkpoint, device=args.device)
+    logger.info(f'Model loaded and the output indices of backbone are '
+                f'{model.backbone.out_indices}.')
     # build the dataset
-    dataset_cfg = Config.fromfile(args.dataset_config)
-    extract_dataloader_cfg = dataset_cfg.get('extract_dataloader')
+    extract_dataloader_cfg = cfg.get('extract_dataloader')
     extract_dataset_cfg = extract_dataloader_cfg.pop('dataset')
     if isinstance(extract_dataset_cfg, dict):
         dataset = DATASETS.build(extract_dataset_cfg)
@@ -184,64 +183,51 @@
         worker_init_fn=init_fn,
         **extract_dataloader_cfg)
-    # build the model
-    # get backbone out_indices from pool_cfg in tsne_dataset_config
-    cfg.model.backbone.out_indices = dataset_cfg.pool_cfg.get(
-        'in_indices', [4])
-    logger.info(
-        f'The output indices of backbone is {cfg.model.backbone.out_indices}.')
-    model = MODELS.build(cfg.model)
-    model.init_weights()
+    results = dict()
+    features = []
+    labels = []
+    progress_bar = mmengine.ProgressBar(len(tsne_dataloader))
+    for _, data in enumerate(tsne_dataloader):
+        with torch.no_grad():
+            # preprocess data
+            data = model.data_preprocessor(data)
+            batch_inputs, batch_data_samples = \
+                data['inputs'], data['data_samples']
-    # model is determined in this priority: init_cfg > checkpoint > random
-    if hasattr(cfg.model.backbone, 'init_cfg'):
-        if getattr(cfg.model.backbone.init_cfg, 'type', None) == 'Pretrained':
-            logger.info(
-                f'Use pretrained model: '
-                f'{cfg.model.backbone.init_cfg.checkpoint} to extract features'
-            )
-    elif args.checkpoint is not None:
-        logger.info(f'Use checkpoint: {args.checkpoint} to extract features')
-        load_checkpoint(model, args.checkpoint, map_location='cpu')
-    else:
-        logger.info('No pretrained or checkpoint is given, use random init.')
+            # extract backbone features
+            batch_features = model.extract_feat(
+                batch_inputs, stage=args.vis_stage)
-    if torch.cuda.is_available():
-        model = model.cuda()
+            # post process
+            if args.vis_stage == 'backbone':
+                if getattr(model.backbone, 'output_cls_token', False) is False:
+                    batch_features = [
+                        F.adaptive_avg_pool2d(inputs, 1).squeeze()
+                        for inputs in batch_features
+                    ]
+                else:
+                    # output_cls_token is True, here t-SNE uses cls_token
+                    batch_features = [feat[-1] for feat in batch_features]
-    if distributed:
-        model = MMDistributedDataParallel(
-            module=model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
+            batch_labels = torch.cat(
+                [i.gt_label.label for i in batch_data_samples])
-    if is_model_wrapper(model):
-        model = model.module
+        # save batch features
+        features.append(batch_features)
+        labels.extend(batch_labels.cpu().numpy())
+        progress_bar.update()
-    # build extractor and extract features
-    extractor = Extractor(
-        extract_dataloader=tsne_dataloader,
-        seed=args.seed,
-        pool_cfg=dataset_cfg.pool_cfg,
-        dist_mode=distributed)
-    features = extractor(model)
-    labels = tsne_dataloader.dataset.get_gt_labels()
+    for i in range(len(features[0])):
+        key = 'feat_' + str(model.backbone.out_indices[i])
+        results[key] = np.concatenate(
+            [batch[i].cpu().numpy() for batch in features], axis=0)
     # save features
     mkdir_or_exist(f'{tsne_work_dir}features/')
     logger.info(f'Save features to {tsne_work_dir}features/')
-    if distributed:
-        rank = get_rank()
-        if rank == 0:
-            for key, val in features.items():
-                output_file = \
-                    f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy'
-                np.save(output_file, val)
-    else:
-        for key, val in features.items():
-            output_file = \
-                f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy'
-            np.save(output_file, val)
+    for key, val in results.items():
+        output_file = f'{tsne_work_dir}features/{key}.npy'
+        np.save(output_file, val)
     # build t-SNE model
     tsne_model = TSNE(
@@ -255,8 +241,8 @@
     # run and get results
     mkdir_or_exist(f'{tsne_work_dir}saved_pictures/')
-    logger.info('Running t-SNE......')
-    for key, val in features.items():
+    logger.info('Running t-SNE.')
+    for key, val in results.items():
         result = tsne_model.fit_transform(val)
         res_min, res_max = result.min(0), result.max(0)
         res_norm = (result - res_min) / (res_max - res_min)
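
The diff ends just after the embedding is normalized; the remainder of the script (not shown here) plots `res_norm` and writes images into `saved_pictures/`, one per feature key. A self-contained illustration of such a per-class scatter, using dummy data and a hypothetical output filename:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
res_norm = rng.random((1000, 2))         # stands in for the normalized t-SNE output
labels = rng.integers(0, 10, size=1000)  # stands in for the ground-truth labels

plt.figure(figsize=(10, 10))
for label_idx in range(10):              # the script iterates up to --max-num-class
    mask = labels == label_idx
    plt.scatter(res_norm[mask, 0], res_norm[mask, 1], s=15)
plt.savefig('tsne_example.png')          # hypothetical filename
```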