[Refactor] Refactor t-SNE (#629)
* refactor tsne
* update configs
* update docs
* add cls_token option

parent fd6659bba5
commit 347b49e77f
@@ -1,26 +0,0 @@
-dataset_type = 'mmcls.ImageNet'
-data_root = 'data/imagenet/'
-file_client_args = dict(backend='disk')
-name = 'imagenet_val'
-
-extract_pipeline = [
-    dict(type='LoadImageFromFile', file_client_args=file_client_args),
-    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
-    dict(type='CenterCrop', crop_size=224),
-    dict(type='PackSelfSupInputs'),
-]
-
-extract_dataloader = dict(
-    batch_size=8,
-    num_workers=4,
-    dataset=dict(
-        type=dataset_type,
-        data_root='data/imagenet',
-        ann_file='meta/val.txt',
-        data_prefix='val',
-        pipeline=extract_pipeline),
-    sampler=dict(type='DefaultSampler', shuffle=False),
-)
-
-# pooling cfg
-pool_cfg = dict(type='MultiPooling', in_indices=(1, 2, 3, 4))
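For reference, the removed `pool_cfg` is what drove multi-stage feature extraction: `MultiPooling` pooled the maps from several backbone stages into fixed-length vectors before t-SNE. Below is a minimal sketch of that idea in plain `torch`; the stage shapes are the standard ResNet-50 ones, and the real `MultiPooling` module may pool to larger fixed grids rather than 1x1.

```python
import torch
import torch.nn.functional as F

# Illustrative ResNet-50 stage outputs for a batch of 8 images.
stage_feats = [
    torch.randn(8, 256, 56, 56),   # stage 1
    torch.randn(8, 512, 28, 28),   # stage 2
    torch.randn(8, 1024, 14, 14),  # stage 3
    torch.randn(8, 2048, 7, 7),    # stage 4
]

# Pool each stage selected by in_indices=(1, 2, 3, 4) to one vector per image.
pooled = [F.adaptive_avg_pool2d(f, 1).flatten(1) for f in stage_feats]
for i, p in enumerate(pooled, start=1):
    print(f'feat_{i}:', tuple(p.shape))  # e.g. feat_4: (8, 2048)
```

After the refactor this per-stage pooling config disappears: each model config below declares a single output stage instead.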
@@ -0,0 +1,48 @@
+_base_ = 'mmcls::_base_/default_runtime.py'
+
+model = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    data_preprocessor=dict(
+        num_classes=1000,
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        to_rgb=True,
+    ),
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        in_channels=3,
+        num_stages=4,
+        out_indices=(3, ),
+        norm_cfg=dict(type='BN'),
+        frozen_stages=-1),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
+
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='mmcls.PackClsInputs'),
+]
+extract_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=extract_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
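This new config is self-contained, so the t-SNE tool no longer needs a separate `--dataset-config`. A sketch of how it might be consumed directly, mirroring the calls the refactored script makes (`register_all_modules` and `init_model` are both imported in the script; the checkpoint argument is optional):

```python
from mmengine.config import Config

from mmselfsup.apis import init_model
from mmselfsup.utils import register_all_modules

# Register mmselfsup/mmcls modules so registry scopes like 'mmcls' resolve.
register_all_modules()

cfg = Config.fromfile('configs/tsne/resnet50_imagenet.py')

# Builds the ImageClassifier and, if a path or URL is given, loads weights.
model = init_model(cfg, None, device='cuda:0')
print(model.backbone.out_indices)  # (3,)
```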
@@ -0,0 +1,42 @@
+_base_ = 'mmcls::_base_/default_runtime.py'
+
+model = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    backbone=dict(
+        type='SwinTransformer',
+        arch='base',
+        img_size=192,
+        out_indices=-1,
+        drop_path_rate=0.1,
+        stage_cfgs=dict(block_cfgs=dict(window_size=6))),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        cal_acc=False))
+
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='mmcls.PackClsInputs'),
+]
+extract_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=extract_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
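With `out_indices=-1` the backbone returns only its last stage. For Swin-base at a 192x192 input that should be a 1024-channel map on a 6x6 grid (192 / 32), which is why the head sets `in_channels=1024`. A quick shape check, assuming the mmcls 1.x registry API:

```python
import torch
from mmcls.registry import MODELS

backbone = MODELS.build(
    dict(
        type='SwinTransformer',
        arch='base',
        img_size=192,
        out_indices=-1,
        stage_cfgs=dict(block_cfgs=dict(window_size=6))))
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 192, 192))

# Expect one (1, 1024, 6, 6) map; GlobalAveragePooling turns it into (1, 1024).
print(feats[-1].shape)
```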
@@ -0,0 +1,45 @@
+_base_ = 'mmcls::_base_/default_runtime.py'
+
+model = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        out_indices=-1,
+        drop_path_rate=0.1,
+        avg_token=False,
+        output_cls_token=False,
+        final_norm=False),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
+)
+
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='mmcls.PackClsInputs'),
+]
+extract_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=extract_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
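Because `output_cls_token=False`, this ViT backbone yields only patch tokens, and the refactored script mean-pools them into one vector per image; flipping the flag would make t-SNE use the class token instead, which is the `cls_token` option named in the commit message. A toy sketch of that branch, with dummy tensors standing in for ViT-base/16 outputs:

```python
import torch
import torch.nn.functional as F

output_cls_token = False  # as set in this config

# Stand-ins for ViT-base/16 at 224x224: a (N, C, H, W) patch-token map,
# plus the (N, C) class token that exists when output_cls_token=True.
patch_token = torch.randn(8, 768, 14, 14)
cls_token = torch.randn(8, 768)

if not output_cls_token:
    # mean-pool patch tokens into one 768-d vector per image
    feat = F.adaptive_avg_pool2d(patch_token, 1).squeeze()
else:
    # visualize the class token directly
    feat = cls_token

print(feat.shape)  # torch.Size([8, 768])
```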
@@ -123,21 +123,26 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH}
 
 Arguments:
 
-- `CONFIG_FILE`: config file for the pre-trained model.
-- `CKPT_PATH`: the path of model's checkpoint.
+- `CONFIG_FILE`: config file for t-SNE, which are listed in the directory `configs/tsne/`.
+- `CKPT_PATH`: the path or link of the model's checkpoint.
 - `WORK_DIR`: the directory to save the results of visualization.
-- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
+- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_tsne.py)
 
-An example:
+An example command:
 
 ```shell
-python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
+python ./tools/analysis_tools/visualize_tsne.py \
+    configs/tsne/resnet50_imagenet.py \
+    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k/mocov2_resnet50_8xb32-coslr-200e_in1k_20220825-b6d23c86.pth \
+    --work-dir ./work_dirs/tsne/mocov2/ \
+    --max-num-class 100
 ```
 
-An example of visualization:
+An example of visualization, where the left image is from `MoCoV2_ResNet50` and the right is from `MAE_ViT-base`:
 
 <div align="center">
-<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
+<img src="https://user-images.githubusercontent.com/36138628/207305086-91df298c-0eb7-4254-9c5b-ba711644501b.png" width="250" />
+<img src="https://user-images.githubusercontent.com/36138628/207305333-59af4747-1e9c-4f85-a57d-c7e5d132a6e5.png" width="250" />
 </div>
 
 ## Visualize Low-level Feature Reconstruction
@@ -6,37 +6,33 @@ from functools import partial
 from typing import Optional
 
 import matplotlib.pyplot as plt
+import mmengine
 import numpy as np
 import torch
+import torch.nn.functional as F
 from mmengine.config import Config, DictAction
 from mmengine.dataset import default_collate, worker_init_fn
-from mmengine.dist import get_rank, init_dist
+from mmengine.dist import get_rank
 from mmengine.logging import MMLogger
-from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
-from mmengine.runner import load_checkpoint
 from mmengine.utils import mkdir_or_exist
 from sklearn.manifold import TSNE
 from torch.utils.data import DataLoader
 
-from mmselfsup.models.utils import Extractor
-from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
+from mmselfsup.apis import init_model
+from mmselfsup.registry import DATA_SAMPLERS, DATASETS
 from mmselfsup.utils import register_all_modules
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='t-SNE visualization')
-    parser.add_argument('config', help='train config file path')
+    parser.add_argument('config', help='tsne config file path')
     parser.add_argument('--checkpoint', default=None, help='checkpoint file')
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
-        '--launcher',
-        choices=['none', 'pytorch', 'slurm', 'mpi'],
-        default='none',
-        help='job launcher')
-    parser.add_argument(
-        '--dataset-config',
-        default='configs/benchmarks/classification/tsne_imagenet.py',
-        help='extract dataset config file path')
+        '--vis-stage',
+        choices=['backbone', 'neck', 'pre_logits'],
+        default='backbone',
+        help='the visualization stage of the model')
     parser.add_argument(
         '--max-num-class',
         type=int,
@@ -58,6 +54,8 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
 
     # t-SNE settings
     parser.add_argument(
@@ -101,6 +99,10 @@ def parse_args():
     return args
 
 
+def post_process():
+    pass
+
+
 def main():
     args = parse_args()
 
@@ -123,13 +125,6 @@ def main():
     cfg.work_dir = osp.join('./work_dirs', work_type,
                             osp.splitext(osp.basename(args.config))[0])
 
-    # init distributed env first, since logger depends on the dist info.
-    if args.launcher == 'none':
-        distributed = False
-    else:
-        distributed = True
-        init_dist(args.launcher)
-
     # create work_dir
     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
     tsne_work_dir = osp.join(cfg.work_dir, f'tsne_{timestamp}/')
@@ -143,9 +138,13 @@ def main():
         log_file=log_file,
         log_level=cfg.log_level)
 
+    # build the model from a config file and a checkpoint file
+    model = init_model(cfg, args.checkpoint, device=args.device)
+    logger.info(f'Model loaded and the output indices of backbone is '
+                f'{model.backbone.out_indices}.')
+
     # build the dataset
-    dataset_cfg = Config.fromfile(args.dataset_config)
-    extract_dataloader_cfg = dataset_cfg.get('extract_dataloader')
+    extract_dataloader_cfg = cfg.get('extract_dataloader')
     extract_dataset_cfg = extract_dataloader_cfg.pop('dataset')
     if isinstance(extract_dataset_cfg, dict):
         dataset = DATASETS.build(extract_dataset_cfg)
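Since the dataset settings now live in the same file as the model, the extraction dataloader comes straight out of `cfg`. A sketch of that lookup on its own, mirroring the split the script performs before assembling the `DataLoader` (config name as in the docs example):

```python
from mmengine.config import Config

from mmselfsup.registry import DATASETS
from mmselfsup.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/tsne/resnet50_imagenet.py')

# Pop the dataset out of extract_dataloader and build it via the registry.
loader_cfg = cfg.get('extract_dataloader').copy()
dataset = DATASETS.build(loader_cfg.pop('dataset'))
print(type(dataset).__name__, len(dataset))
```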
@@ -184,64 +183,51 @@ def main():
         worker_init_fn=init_fn,
         **extract_dataloader_cfg)
 
-    # build the model
-    # get backbone out_indices from pool_cfg in tsne_dataset_config
-    cfg.model.backbone.out_indices = dataset_cfg.pool_cfg.get(
-        'in_indices', [4])
-    logger.info(
-        f'The output indices of backbone is {cfg.model.backbone.out_indices}.')
-    model = MODELS.build(cfg.model)
-    model.init_weights()
-
+    results = dict()
+    features = []
+    labels = []
+    progress_bar = mmengine.ProgressBar(len(tsne_dataloader))
+    for _, data in enumerate(tsne_dataloader):
+        with torch.no_grad():
+            # preprocess data
+            data = model.data_preprocessor(data)
+            batch_inputs, batch_data_samples = \
+                data['inputs'], data['data_samples']
+
-    # model is determined in this priority: init_cfg > checkpoint > random
-    if hasattr(cfg.model.backbone, 'init_cfg'):
-        if getattr(cfg.model.backbone.init_cfg, 'type', None) == 'Pretrained':
-            logger.info(
-                f'Use pretrained model: '
-                f'{cfg.model.backbone.init_cfg.checkpoint} to extract features'
-            )
-    elif args.checkpoint is not None:
-        logger.info(f'Use checkpoint: {args.checkpoint} to extract features')
-        load_checkpoint(model, args.checkpoint, map_location='cpu')
-    else:
-        logger.info('No pretrained or checkpoint is given, use random init.')
-
+            # extract backbone features
+            batch_features = model.extract_feat(
+                batch_inputs, stage=args.vis_stage)
+
-    if torch.cuda.is_available():
-        model = model.cuda()
-
+            # post process
+            if args.vis_stage == 'backbone':
+                if getattr(model.backbone, 'output_cls_token', False) is False:
+                    batch_features = [
+                        F.adaptive_avg_pool2d(inputs, 1).squeeze()
+                        for inputs in batch_features
+                    ]
+                else:
+                    # output_cls_token is True, here t-SNE uses cls_token
+                    batch_features = [feat[-1] for feat in batch_features]
+
-    if distributed:
-        model = MMDistributedDataParallel(
-            module=model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
-
+            batch_labels = torch.cat(
+                [i.gt_label.label for i in batch_data_samples])
+
-    if is_model_wrapper(model):
-        model = model.module
-
+        # save batch features
+        features.append(batch_features)
+        labels.extend(batch_labels.cpu().numpy())
+        progress_bar.update()
+
-    # build extractor and extract features
-    extractor = Extractor(
-        extract_dataloader=tsne_dataloader,
-        seed=args.seed,
-        pool_cfg=dataset_cfg.pool_cfg,
-        dist_mode=distributed)
-    features = extractor(model)
-    labels = tsne_dataloader.dataset.get_gt_labels()
+    for i in range(len(features[0])):
+        key = 'feat_' + str(model.backbone.out_indices[i])
+        results[key] = np.concatenate(
+            [batch[i].cpu().numpy() for batch in features], axis=0)
 
     # save features
     mkdir_or_exist(f'{tsne_work_dir}features/')
     logger.info(f'Save features to {tsne_work_dir}features/')
-    if distributed:
-        rank = get_rank()
-        if rank == 0:
-            for key, val in features.items():
-                output_file = \
-                    f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy'
-                np.save(output_file, val)
-    else:
-        for key, val in features.items():
-            output_file = \
-                f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy'
-            np.save(output_file, val)
+    for key, val in results.items():
+        output_file = f'{tsne_work_dir}features/{key}.npy'
+        np.save(output_file, val)
 
     # build t-SNE model
     tsne_model = TSNE(
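After the loop, `features` is a list with one entry per batch, each entry itself a list holding one tensor per backbone output index; the final loop transposes that into one `(num_samples, dim)` array per index. A toy illustration of the aggregation:

```python
import numpy as np
import torch

out_indices = (3, )  # as in the ResNet-50 config

# Two toy batches; each batch holds one pooled tensor per output index.
features = [
    [torch.randn(8, 2048)],  # batch 1
    [torch.randn(8, 2048)],  # batch 2
]

results = {}
for i in range(len(features[0])):
    key = 'feat_' + str(out_indices[i])
    results[key] = np.concatenate(
        [batch[i].cpu().numpy() for batch in features], axis=0)

print(results['feat_3'].shape)  # (16, 2048)
```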
@@ -255,8 +241,8 @@ def main():
 
     # run and get results
     mkdir_or_exist(f'{tsne_work_dir}saved_pictures/')
-    logger.info('Running t-SNE......')
-    for key, val in features.items():
+    logger.info('Running t-SNE.')
+    for key, val in results.items():
         result = tsne_model.fit_transform(val)
         res_min, res_max = result.min(0), result.max(0)
         res_norm = (result - res_min) / (res_max - res_min)
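The saved `feat_*.npy` files can also be embedded and plotted outside the tool. Below is a standalone sketch with random stand-in data that reuses the same min-max normalization as the script; the `TSNE` settings here are placeholders, not necessarily the tool's CLI defaults.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

# Stand-ins: 200 samples of 2048-d features with 10 class labels. In practice,
# load a feat_*.npy file saved under {work_dir}/tsne_{timestamp}/features/.
features = np.random.randn(200, 2048).astype(np.float32)
labels = np.random.randint(0, 10, size=200)

result = TSNE(n_components=2, init='pca', perplexity=30).fit_transform(features)

# The same min-max normalization the script applies before plotting.
res_min, res_max = result.min(0), result.max(0)
res_norm = (result - res_min) / (res_max - res_min)

plt.figure(figsize=(6, 6))
plt.scatter(res_norm[:, 0], res_norm[:, 1], c=labels, s=8, cmap='tab10')
plt.savefig('tsne_example.png')
```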