From 347b49e77fa0643740bb4fc4570699b106d6cdb3 Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Wed, 21 Dec 2022 16:21:12 +0800
Subject: [PATCH] [Refactor] Refactor t-SNE (#629)

* refactor tsne
* update configs
* update docs
* add cls_token option
---
 .../classification/tsne_imagenet.py | 26 ----
 configs/tsne/resnet50_imagenet.py | 48 +++++++
 configs/tsne/swin-base_imagenet.py | 42 ++++++
 configs/tsne/vit-base-p16_imagenet.py | 45 ++++++
 docs/en/user_guides/visualization.md | 19 ++-
 tools/analysis_tools/visualize_tsne.py | 134 ++++++++----------
 6 files changed, 207 insertions(+), 107 deletions(-)
 delete mode 100644 configs/benchmarks/classification/tsne_imagenet.py
 create mode 100644 configs/tsne/resnet50_imagenet.py
 create mode 100644 configs/tsne/swin-base_imagenet.py
 create mode 100644 configs/tsne/vit-base-p16_imagenet.py

diff --git a/configs/benchmarks/classification/tsne_imagenet.py b/configs/benchmarks/classification/tsne_imagenet.py
deleted file mode 100644
index 0dc0bf68..00000000
--- a/configs/benchmarks/classification/tsne_imagenet.py
+++ /dev/null
@@ -1,26 +0,0 @@
-dataset_type = 'mmcls.ImageNet'
-data_root = 'data/imagenet/'
-file_client_args = dict(backend='disk')
-name = 'imagenet_val'
-
-extract_pipeline = [
-    dict(type='LoadImageFromFile', file_client_args=file_client_args),
-    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
-    dict(type='CenterCrop', crop_size=224),
-    dict(type='PackSelfSupInputs'),
-]
-
-extract_dataloader = dict(
-    batch_size=8,
-    num_workers=4,
-    dataset=dict(
-        type=dataset_type,
-        data_root='data/imagenet',
-        ann_file='meta/val.txt',
-        data_prefix='val',
-        pipeline=extract_pipeline),
-    sampler=dict(type='DefaultSampler', shuffle=False),
-)
-
-# pooling cfg
-pool_cfg = dict(type='MultiPooling', in_indices=(1, 2, 3, 4))
diff --git a/configs/tsne/resnet50_imagenet.py b/configs/tsne/resnet50_imagenet.py
new file mode 100644
index 00000000..7d93906c
--- /dev/null
+++ b/configs/tsne/resnet50_imagenet.py
@@ -0,0 +1,48 @@
+_base_ = 'mmcls::_base_/default_runtime.py'
+
+model = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    data_preprocessor=dict(
+        num_classes=1000,
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        to_rgb=True,
+    ),
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        in_channels=3,
+        num_stages=4,
+        out_indices=(3, ),
+        norm_cfg=dict(type='BN'),
+        frozen_stages=-1),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=2048,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
+
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='mmcls.PackClsInputs'),
+]
+extract_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=extract_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
diff --git a/configs/tsne/swin-base_imagenet.py b/configs/tsne/swin-base_imagenet.py
new file mode 100644
index 00000000..64827a66
--- /dev/null
+++ b/configs/tsne/swin-base_imagenet.py
@@ -0,0 +1,42 @@
+_base_ = 'mmcls::_base_/default_runtime.py'
+
+model = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    backbone=dict(
+        type='SwinTransformer',
+        arch='base',
+        img_size=192,
+        out_indices=-1,
+        drop_path_rate=0.1,
+        stage_cfgs=dict(block_cfgs=dict(window_size=6))),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        init_cfg=None,  # suppress the default init_cfg of LinearClsHead.
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        cal_acc=False))
+
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='mmcls.PackClsInputs'),
+]
+extract_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=extract_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
diff --git a/configs/tsne/vit-base-p16_imagenet.py b/configs/tsne/vit-base-p16_imagenet.py
new file mode 100644
index 00000000..1569e589
--- /dev/null
+++ b/configs/tsne/vit-base-p16_imagenet.py
@@ -0,0 +1,45 @@
+_base_ = 'mmcls::_base_/default_runtime.py'
+
+model = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        out_indices=-1,
+        drop_path_rate=0.1,
+        avg_token=False,
+        output_cls_token=False,
+        final_norm=False),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
+)
+
+dataset_type = 'mmcls.ImageNet'
+data_root = 'data/imagenet/'
+file_client_args = dict(backend='disk')
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='mmcls.PackClsInputs'),
+]
+extract_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=extract_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
diff --git a/docs/en/user_guides/visualization.md b/docs/en/user_guides/visualization.md
index 6f167370..3f6e710a 100644
--- a/docs/en/user_guides/visualization.md
+++ b/docs/en/user_guides/visualization.md
@@ -123,21 +123,26 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT
 Arguments:
-- `CONFIG_FILE`: config file for the pre-trained model.
-- `CKPT_PATH`: the path of model's checkpoint.
+- `CONFIG_FILE`: the t-SNE config file, which is listed in the directory `configs/tsne/`
+- `CKPT_PATH`: the path or URL of the model's checkpoint.
+- `WORK_DIR`: the directory to save the results of visualization.
-- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
+- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_tsne.py)
-An example:
+An example command:
 ```shell
-python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
+python ./tools/analysis_tools/visualize_tsne.py \
+    configs/tsne/resnet50_imagenet.py \
+    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k/mocov2_resnet50_8xb32-coslr-200e_in1k_20220825-b6d23c86.pth \
+    --work-dir ./work_dirs/tsne/mocov2/ \
+    --max-num-class 100
 ```
-An example of visualization:
+An example of the visualization: the left figure is from `MoCoV2_ResNet50` and the right one is from `MAE_ViT-base`:
+[figures: t-SNE visualizations of MoCoV2_ResNet50 (left) and MAE_ViT-base (right)]
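The extracted features are also saved as `.npy` files under `<WORK_DIR>/tsne_<timestamp>/features/`, named `feat_<out_index>.npy`, so they can be reloaded for custom plots. A minimal sketch, assuming a ResNet-50 run whose backbone outputs stage 3; the timestamped directory name below is only an example, use the one created by your run:

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

# Feature file written by visualize_tsne.py; adjust the timestamp to your run.
feat_file = 'work_dirs/tsne/mocov2/tsne_20221221_162112/features/feat_3.npy'
features = np.load(feat_file)  # shape: (num_images, feature_dim)

# Re-embed with your own t-SNE settings and draw a simple scatter plot.
result = TSNE(n_components=2, init='random', random_state=0).fit_transform(features)
result = (result - result.min(0)) / (result.max(0) - result.min(0))  # min-max normalize

plt.figure(figsize=(8, 8))
plt.scatter(result[:, 0], result[:, 1], s=5)
plt.savefig('tsne_custom.png')
```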
## Visualize Low-level Feature Reconstruction diff --git a/tools/analysis_tools/visualize_tsne.py b/tools/analysis_tools/visualize_tsne.py index 9459a973..10c2a3a8 100644 --- a/tools/analysis_tools/visualize_tsne.py +++ b/tools/analysis_tools/visualize_tsne.py @@ -6,37 +6,33 @@ from functools import partial from typing import Optional import matplotlib.pyplot as plt +import mmengine import numpy as np import torch +import torch.nn.functional as F from mmengine.config import Config, DictAction from mmengine.dataset import default_collate, worker_init_fn -from mmengine.dist import get_rank, init_dist +from mmengine.dist import get_rank from mmengine.logging import MMLogger -from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper -from mmengine.runner import load_checkpoint from mmengine.utils import mkdir_or_exist from sklearn.manifold import TSNE from torch.utils.data import DataLoader -from mmselfsup.models.utils import Extractor -from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS +from mmselfsup.apis import init_model +from mmselfsup.registry import DATA_SAMPLERS, DATASETS from mmselfsup.utils import register_all_modules def parse_args(): parser = argparse.ArgumentParser(description='t-SNE visualization') - parser.add_argument('config', help='train config file path') + parser.add_argument('config', help='tsne config file path') parser.add_argument('--checkpoint', default=None, help='checkpoint file') parser.add_argument('--work-dir', help='the dir to save logs and models') parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') - parser.add_argument( - '--dataset-config', - default='configs/benchmarks/classification/tsne_imagenet.py', - help='extract dataset config file path') + '--vis-stage', + choices=['backbone', 'neck', 'pre_logits'], + default='backbone', + help='the visualization stage of the model') parser.add_argument( '--max-num-class', type=int, @@ -58,6 +54,8 @@ def parse_args(): 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') # t-SNE settings parser.add_argument( @@ -101,6 +99,10 @@ def parse_args(): return args +def post_process(): + pass + + def main(): args = parse_args() @@ -123,13 +125,6 @@ def main(): cfg.work_dir = osp.join('./work_dirs', work_type, osp.splitext(osp.basename(args.config))[0]) - # init distributed env first, since logger depends on the dist info. 
- if args.launcher == 'none': - distributed = False - else: - distributed = True - init_dist(args.launcher) - # create work_dir timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) tsne_work_dir = osp.join(cfg.work_dir, f'tsne_{timestamp}/') @@ -143,9 +138,13 @@ def main(): log_file=log_file, log_level=cfg.log_level) + # build the model from a config file and a checkpoint file + model = init_model(cfg, args.checkpoint, device=args.device) + logger.info(f'Model loaded and the output indices of backbone is ' + f'{model.backbone.out_indices}.') + # build the dataset - dataset_cfg = Config.fromfile(args.dataset_config) - extract_dataloader_cfg = dataset_cfg.get('extract_dataloader') + extract_dataloader_cfg = cfg.get('extract_dataloader') extract_dataset_cfg = extract_dataloader_cfg.pop('dataset') if isinstance(extract_dataset_cfg, dict): dataset = DATASETS.build(extract_dataset_cfg) @@ -184,64 +183,51 @@ def main(): worker_init_fn=init_fn, **extract_dataloader_cfg) - # build the model - # get backbone out_indices from pool_cfg in tsne_dataset_config - cfg.model.backbone.out_indices = dataset_cfg.pool_cfg.get( - 'in_indices', [4]) - logger.info( - f'The output indices of backbone is {cfg.model.backbone.out_indices}.') - model = MODELS.build(cfg.model) - model.init_weights() + results = dict() + features = [] + labels = [] + progress_bar = mmengine.ProgressBar(len(tsne_dataloader)) + for _, data in enumerate(tsne_dataloader): + with torch.no_grad(): + # preprocess data + data = model.data_preprocessor(data) + batch_inputs, batch_data_samples = \ + data['inputs'], data['data_samples'] - # model is determined in this priority: init_cfg > checkpoint > random - if hasattr(cfg.model.backbone, 'init_cfg'): - if getattr(cfg.model.backbone.init_cfg, 'type', None) == 'Pretrained': - logger.info( - f'Use pretrained model: ' - f'{cfg.model.backbone.init_cfg.checkpoint} to extract features' - ) - elif args.checkpoint is not None: - logger.info(f'Use checkpoint: {args.checkpoint} to extract features') - load_checkpoint(model, args.checkpoint, map_location='cpu') - else: - logger.info('No pretrained or checkpoint is given, use random init.') + # extract backbone features + batch_features = model.extract_feat( + batch_inputs, stage=args.vis_stage) - if torch.cuda.is_available(): - model = model.cuda() + # post process + if args.vis_stage == 'backbone': + if getattr(model.backbone, 'output_cls_token', False) is False: + batch_features = [ + F.adaptive_avg_pool2d(inputs, 1).squeeze() + for inputs in batch_features + ] + else: + # output_cls_token is True, here t-SNE uses cls_token + batch_features = [feat[-1] for feat in batch_features] - if distributed: - model = MMDistributedDataParallel( - module=model.cuda(), - device_ids=[torch.cuda.current_device()], - broadcast_buffers=False) + batch_labels = torch.cat( + [i.gt_label.label for i in batch_data_samples]) - if is_model_wrapper(model): - model = model.module + # save batch features + features.append(batch_features) + labels.extend(batch_labels.cpu().numpy()) + progress_bar.update() - # build extractor and extract features - extractor = Extractor( - extract_dataloader=tsne_dataloader, - seed=args.seed, - pool_cfg=dataset_cfg.pool_cfg, - dist_mode=distributed) - features = extractor(model) - labels = tsne_dataloader.dataset.get_gt_labels() + for i in range(len(features[0])): + key = 'feat_' + str(model.backbone.out_indices[i]) + results[key] = np.concatenate( + [batch[i].cpu().numpy() for batch in features], axis=0) # save features 
mkdir_or_exist(f'{tsne_work_dir}features/') logger.info(f'Save features to {tsne_work_dir}features/') - if distributed: - rank = get_rank() - if rank == 0: - for key, val in features.items(): - output_file = \ - f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy' - np.save(output_file, val) - else: - for key, val in features.items(): - output_file = \ - f'{tsne_work_dir}features/{dataset_cfg.name}_{key}.npy' - np.save(output_file, val) + for key, val in results.items(): + output_file = f'{tsne_work_dir}features/{key}.npy' + np.save(output_file, val) # build t-SNE model tsne_model = TSNE( @@ -255,8 +241,8 @@ def main(): # run and get results mkdir_or_exist(f'{tsne_work_dir}saved_pictures/') - logger.info('Running t-SNE......') - for key, val in features.items(): + logger.info('Running t-SNE.') + for key, val in results.items(): result = tsne_model.fit_transform(val) res_min, res_max = result.min(0), result.max(0) res_norm = (result - res_min) / (res_max - res_min)
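
To make the refactored extraction and post-processing loop above easier to follow, here is a standalone sketch of the same steps with dummy tensors standing in for real backbone outputs; the tensor shapes and the t-SNE settings are illustrative only, not the script's exact defaults:

```python
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.manifold import TSNE

# Two dummy "batches" of backbone outputs: one (N, C, H, W) feature map per
# entry of out_indices (a single stage here, as in the ResNet-50 config).
batches = [(torch.randn(32, 2048, 7, 7), ), (torch.randn(32, 2048, 7, 7), )]

features = []
for batch_features in batches:
    # When output_cls_token is False, each feature map is average-pooled
    # down to an (N, C) embedding; with the cls_token option enabled, the
    # script keeps feat[-1] (the cls token) instead.
    batch_features = [
        F.adaptive_avg_pool2d(feat, 1).squeeze() for feat in batch_features
    ]
    features.append(batch_features)

# Concatenate all batches for the first (and only) output stage -> (64, 2048).
feats = np.concatenate([batch[0].numpy() for batch in features], axis=0)

# Embed with t-SNE and min-max normalize to [0, 1] before plotting,
# mirroring what the script does with the saved features.
result = TSNE(n_components=2, init='random', random_state=0).fit_transform(feats)
res_norm = (result - result.min(0)) / (result.max(0) - result.min(0))
print(res_norm.shape)  # (64, 2)
```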