[Feature] Support pixel reconstruction visualization (#570)
* refactor reconstruction visualization
* support simmim visualization
* fix reconstruction bug of MAE
* support visualization of MaskFeat
* refactor mae visualization demo
* add unit test
* fix lint and ut
* update
* add docs
* set random seed
* update
* update docstring
* add torch version check
* update
* rename
* update version
* update
* fix lint
* add docstring
* update docs

parent d73c953804
commit 73cd764b5f
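At a glance, the refactored flow separates feature extraction (`inference_model`) from reconstruction (`model.reconstruct`). A condensed sketch of the MAE notebook changed below; the config and checkpoint names are the ones used in the demo, and `fox.jpg` is the demo image:

```python
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

register_all_modules()

# Build the model from the demo config and the released checkpoint.
model = init_model(
    'configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',
    'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth',
    device='cpu')

# Preprocess the demo image with the test pipeline.
vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
data = default_collate([vis_pipeline(dict(img_path='fox.jpg'))])
img, _ = model.data_preprocessor(data, False)

# Per-patch mean/std of the target image undo MAE's norm_pix normalization.
target = model.head.patchify(img[0])
mean = target.mean(dim=-1, keepdim=True)
std = (target.var(dim=-1, keepdim=True) + 1.e-6)**.5

# Extract raw features, then reconstruct pixels from them.
features = inference_model(model, 'fox.jpg')
results = model.reconstruct(features, mean=mean, std=std)
```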
@@ -37,3 +37,19 @@ train_dataloader = dict(
         ann_file='meta/train.txt',
         data_prefix=dict(img_path='train/'),
         pipeline=train_pipeline))
+
+# for visualization
+vis_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='Resize', scale=(224, 224), backend='pillow'),
+    dict(
+        type='BEiTMaskGenerator',
+        input_size=14,
+        num_masking_patches=78,
+        min_num_patches=15,
+    ),
+    dict(
+        type='PackSelfSupInputs',
+        algorithm_keys=['mask'],
+        meta_keys=['img_path'])
+]

@@ -34,3 +34,19 @@ train_dataloader = dict(
         ann_file='meta/train.txt',
         data_prefix=dict(img_path='train/'),
         pipeline=train_pipeline))
+
+# for visualization
+vis_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='Resize', scale=(192, 192), backend='pillow'),
+    dict(
+        type='SimMIMMaskGenerator',
+        input_size=192,
+        mask_patch_size=32,
+        model_patch_size=4,
+        mask_ratio=0.6),
+    dict(
+        type='PackSelfSupInputs',
+        algorithm_keys=['mask'],
+        meta_keys=['img_path'])
+]

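The `vis_pipeline` added to these configs is what the visualization tool's `--use-vis-pipeline` flag picks up, since mask generators such as SimMIM's live in the data pipeline rather than in the model. A minimal sketch of how such a pipeline is composed and applied (the image path is a placeholder):

```python
from mmengine.dataset import Compose, default_collate

from mmselfsup.utils import register_all_modules

register_all_modules()

# Mirrors the SimMIM vis_pipeline above: load, resize, generate the mask,
# and pack the sample together with its mask.
vis_pipeline = Compose([
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(type='Resize', scale=(192, 192), backend='pillow'),
    dict(
        type='SimMIMMaskGenerator',
        input_size=192,
        mask_patch_size=32,
        model_patch_size=4,
        mask_ratio=0.6),
    dict(
        type='PackSelfSupInputs',
        algorithm_keys=['mask'],
        meta_keys=['img_path'])
])

data = default_collate([vis_pipeline(dict(img_path='fox.jpg'))])
```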
@@ -6,8 +6,6 @@
    "source": [
     "Copyright (c) OpenMMLab. All rights reserved.\n",
     "\n",
-    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
-    "\n",
     "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
     "\n",
     "## Masked Autoencoders: Visualization Demo\n",
@@ -36,7 +34,8 @@
     "    print('Running in Colab.')\n",
     "    !pip3 install openmim\n",
     "    !pip install -U openmim\n",
-    "    !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
+    "    !mim install mmengine\n",
+    "    !mim install 'mmcv>=2.0.0rc1'\n",
     "\n",
     "    !git clone https://github.com/open-mmlab/mmselfsup.git\n",
     "    %cd mmselfsup/\n",
@@ -51,18 +50,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [],
    "source": [
+    "from argparse import ArgumentParser\n",
+    "from typing import Tuple, Optional\n",
+    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import torch\n",
     "from mmengine.dataset import Compose, default_collate\n",
     "\n",
-    "from mmselfsup.apis import inference_model\n",
-    "from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
-    "from mmselfsup.registry import MODELS\n",
+    "from mmselfsup.apis import inference_model, init_model\n",
     "from mmselfsup.utils import register_all_modules"
    ]
   },
@@ -75,7 +75,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -84,49 +84,81 @@
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
-    "def show_image(image, title=''):\n",
+    "\n",
+    "def show_image(img: torch.Tensor, title: str = '') -> None:\n",
     "    # image is [H, W, 3]\n",
-    "    assert image.shape[2] == 3\n",
-    "    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
-    "    plt.imshow(image)\n",
+    "    assert img.shape[2] == 3\n",
+    "\n",
+    "    plt.imshow(img)\n",
     "    plt.title(title, fontsize=16)\n",
     "    plt.axis('off')\n",
     "    return\n",
     "\n",
     "\n",
-    "def show_images(x, im_masked, y, im_paste):\n",
+    "def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n",
+    "                pred_img: torch.Tensor, img_paste: torch.Tensor,\n",
+    "                out_file: Optional[str] = None) -> None:\n",
     "    # make the plt figure larger\n",
     "    plt.rcParams['figure.figsize'] = [24, 6]\n",
     "\n",
     "    plt.subplot(1, 4, 1)\n",
-    "    show_image(x, \"original\")\n",
+    "    show_image(original_img, 'original')\n",
     "\n",
     "    plt.subplot(1, 4, 2)\n",
-    "    show_image(im_masked, \"masked\")\n",
+    "    show_image(img_masked, 'masked')\n",
     "\n",
     "    plt.subplot(1, 4, 3)\n",
-    "    show_image(y, \"reconstruction\")\n",
+    "    show_image(pred_img, 'reconstruction')\n",
     "\n",
     "    plt.subplot(1, 4, 4)\n",
-    "    show_image(im_paste, \"reconstruction + visible\")\n",
+    "    show_image(img_paste, 'reconstruction + visible')\n",
     "\n",
-    "    plt.show()\n",
+    "    if out_file is None:\n",
+    "        plt.show()\n",
+    "    else:\n",
+    "        plt.savefig(out_file)\n",
+    "        print(f'Images are saved to {out_file}')\n",
     "\n",
     "\n",
-    "def post_process(x, y, mask):\n",
-    "    x = torch.einsum('nchw->nhwc', x.cpu())\n",
+    "def recover_norm(img: torch.Tensor,\n",
+    "                 mean: np.ndarray = imagenet_mean,\n",
+    "                 std: np.ndarray = imagenet_std):\n",
+    "    if mean is not None and std is not None:\n",
+    "        img = torch.clip((img * std + mean) * 255, 0, 255).int()\n",
+    "    return img\n",
+    "\n",
+    "\n",
+    "def post_process(\n",
+    "    original_img: torch.Tensor,\n",
+    "    pred_img: torch.Tensor,\n",
+    "    mask: torch.Tensor,\n",
+    "    mean: np.ndarray = imagenet_mean,\n",
+    "    std: np.ndarray = imagenet_std\n",
+    ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
+    "    # channel conversion\n",
+    "    original_img = torch.einsum('nchw->nhwc', original_img.cpu())\n",
     "    # masked image\n",
-    "    im_masked = x * (1 - mask)\n",
-    "    # MAE reconstruction pasted with visible patches\n",
-    "    im_paste = x * (1 - mask) + y * mask\n",
-    "    return x[0], im_masked[0], y[0], im_paste[0]"
+    "    img_masked = original_img * (1 - mask)\n",
+    "    # reconstructed image pasted with visible patches\n",
+    "    img_paste = original_img * (1 - mask) + pred_img * mask\n",
+    "\n",
+    "    # multiply std and add mean to each image\n",
+    "    original_img = recover_norm(original_img[0])\n",
+    "    img_masked = recover_norm(img_masked[0])\n",
+    "\n",
+    "    pred_img = recover_norm(pred_img[0])\n",
+    "    img_paste = recover_norm(img_paste[0])\n",
+    "\n",
+    "    return original_img, img_masked, pred_img, img_paste\n"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Prepare config file"
+    "### Load a pre-trained MAE model\n",
+    "\n",
+    "This is an MAE model trained with config 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k.py'.\n"
    ]
   },
   {
@@ -138,55 +170,20 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
+    "--2022-11-08 11:00:50--  https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
+    "Resolving download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
+    "Connecting to download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... connected.\n",
+    "HTTP request sent, awaiting response... 200 OK\n",
+    "Length: 1355429265 (1.3G) [application/octet-stream]\n",
+    "Saving to: 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth'\n",
+    "\n",
+    "e_in1k_20220825-cc7  99%[==================> ]   1.26G   913KB/s    eta 0s      "
    ]
   }
  ],
   "source": [
-    "%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
-    "model = dict(\n",
-    "    type='MAE',\n",
-    "    data_preprocessor=dict(\n",
-    "        mean=[123.675, 116.28, 103.53],\n",
-    "        std=[58.395, 57.12, 57.375],\n",
-    "        bgr_to_rgb=True),\n",
-    "    backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
-    "    neck=dict(\n",
-    "        type='MAEPretrainDecoder',\n",
-    "        patch_size=16,\n",
-    "        in_chans=3,\n",
-    "        embed_dim=1024,\n",
-    "        decoder_embed_dim=512,\n",
-    "        decoder_depth=8,\n",
-    "        decoder_num_heads=16,\n",
-    "        mlp_ratio=4.,\n",
-    "    ),\n",
-    "    head=dict(\n",
-    "        type='MAEPretrainHead',\n",
-    "        norm_pix=True,\n",
-    "        patch_size=16,\n",
-    "        loss=dict(type='MAEReconstructionLoss')),\n",
-    "    init_cfg=[\n",
-    "        dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
-    "        dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
-    "    ])\n",
-    "\n",
-    "file_client_args = dict(backend='disk')\n",
-    "\n",
-    "# dataset summary\n",
-    "test_dataloader = dict(\n",
-    "    dataset=dict(pipeline=[\n",
-    "        dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
-    "        dict(type='Resize', scale=(224, 224)),\n",
-    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
-    "    ]))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Load a pre-trained MAE model"
+    "# download checkpoint if not exist\n",
+    "!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth"
    ]
   },
   {
@@ -198,46 +195,21 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "--2022-09-03 00:34:55--  https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
-    "Resolving download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
-    "Connecting to download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... connected.\n",
-    "HTTP request sent, awaiting response... 200 OK\n",
-    "Length: 1318299501 (1.2G) [application/octet-stream]\n",
-    "Saving to: 'mae_visualize_vit_large.pth'\n",
+    "local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
+    "The model and loaded state dict do not match exactly\n",
     "\n",
-    "mae_visualize_vit_l 100%[===================>]   1.23G  3.22MB/s    in 6m 4s   \n",
+    "unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n",
     "\n",
-    "2022-09-03 00:40:59 (3.46 MB/s) - 'mae_visualize_vit_large.pth' saved [1318299501/1318299501])\n",
-    "\n"
-   ]
-  }
-  ],
-  "source": [
-    "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
-    "\n",
-    "# download checkpoint if not exist\n",
-    "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
-    "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "local loads checkpoint from path: mae_visualize_vit_large.pth\n",
     "Model loaded.\n"
    ]
   }
  ],
   "source": [
-    "from mmselfsup.apis import init_model\n",
-    "ckpt_path = \"mae_visualize_vit_large.pth\"\n",
-    "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
+    "ckpt_path = \"mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\"\n",
+    "model = init_model(\n",
+    "    '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',\n",
+    "    ckpt_path,\n",
+    "    device='cpu')\n",
     "print('Model loaded.')"
    ]
   },
@@ -250,16 +222,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
    {
     "data": {
      "text/plain": [
-      "<torch._C.Generator at 0x7f5029d19950>"
+      "<torch._C.Generator at 0x7fb2ccfbac90>"
      ]
     },
-    "execution_count": 7,
+    "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -272,23 +244,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "--2022-09-03 00:41:01--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
-     "Resolving download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
-     "Connecting to download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... connected.\n",
+     "--2022-11-08 11:21:14--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
+     "Resolving download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
+     "Connecting to download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... connected.\n",
     "HTTP request sent, awaiting response... 200 OK\n",
     "Length: 60133 (59K) [image/jpeg]\n",
     "Saving to: 'fox.jpg'\n",
     "\n",
-     "fox.jpg             100%[===================>]  58.72K  --.-KB/s    in 0.06s   \n",
+     "fox.jpg             100%[===================>]  58.72K  --.-KB/s    in 0.05s   \n",
     "\n",
-     "2022-09-03 00:41:01 (962 KB/s) - 'fox.jpg' saved [60133/60133])\n",
+     "2022-11-08 11:21:15 (1.08 MB/s) - 'fox.jpg' saved [60133/60133])\n",
     "\n"
    ]
   }
@@ -299,22 +271,34 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
     "img_path = 'fox.jpg'"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Build Pipeline"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
-    "cfg = model.cfg\n",
-    "test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
-    "data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
+    "model.cfg.test_dataloader = dict(\n",
+    "    dataset=dict(pipeline=[\n",
+    "        dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),\n",
+    "        dict(type='Resize', scale=(224, 224), backend='pillow'),\n",
+    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
+    "    ]))\n",
+    "\n",
+    "vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)"
    ]
   },
   {
@@ -324,36 +308,62 @@
    "outputs": [],
    "source": [
     "data = dict(img_path=img_path)\n",
-    "data = test_pipeline(data)\n",
+    "data = vis_pipeline(data)\n",
     "data = default_collate([data])\n",
-    "img, _ = data_preprocessor(data, False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "plt.rcParams['figure.figsize'] = [5, 5]\n",
-    "show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
+    "img, _ = model.data_preprocessor(data, False)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Run MAE on the image"
+    "### Reconstruction pipeline"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
-    "results = inference_model(model, img_path)\n",
-    "x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
+    "# for MAE reconstruction\n",
+    "img_embedding = model.head.patchify(img[0])\n",
+    "# normalize the target image\n",
+    "mean = img_embedding.mean(dim=-1, keepdim=True)\n",
+    "std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get reconstruction image\n",
+    "features = inference_model(model, img_path)\n",
+    "results = model.reconstruct(features, mean=mean, std=std)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "original_target = img[0]\n",
+    "original_img, img_masked, pred_img, img_paste = post_process(\n",
+    "    original_target,\n",
+    "    results.pred.value,\n",
+    "    results.mask.value,\n",
+    "    mean=mean,\n",
+    "    std=std)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Show the image"
    ]
   },
   {
@@ -362,21 +372,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "print('MAE with pixel reconstruction:')\n",
-    "show_images(x, im_masked, y, im_paste)"
+    "save_images(original_img, img_masked, pred_img, img_paste)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3.7.0 ('openmmlab')",
    "language": "python",
    "name": "python3"
   },
@@ -390,11 +392,11 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.13"
+   "version": "3.7.0"
   },
   "vscode": {
    "interpreter": {
-    "hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
+    "hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361"
    }
   }
  },

@@ -7,8 +7,6 @@
 - [Publish a model](#publish-a-model)
 - [Reproducibility](#reproducibility)
 - [Log Analysis](#log-analysis)
-- [Visualize Datasets](#visualize-datasets)
-- [Use t-SNE](#use-t-sne)
 
 ## Count number of parameters
 
@@ -92,53 +90,3 @@ Examples:
 time std over epochs is 0.0028
 average iter time: 1.1959 s/iter
 ```
-
-## Visualize Datasets
-
-`tools/misc/browse_dataset.py` helps the user to browse a mmselfsup dataset (transformed images) visually, or save the image to a designated directory.
-
-```shell
-python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
-```
-
-An example:
-
-```shell
-python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py
-```
-
-An example of visualization:
-
-<div align="center">
-<img src="https://user-images.githubusercontent.com/36138628/199387454-219e6f6c-fbb7-43bb-b319-61d3e6266abc.png" width="600" />
-</div>
-
-- The left two pictures are images from contrastive learning data pipeline.
-- The right one is a masked image.
-
-## Use t-SNE
-
-We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE.
-
-```shell
-python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments]
-```
-
-Arguments:
-
-- `CONFIG_FILE`: config file for the pre-trained model.
-- `CKPT_PATH`: the path of model's checkpoint.
-- `WORK_DIR`: the directory to save the results of visualization.
-- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
-
-An example:
-
-```shell
-python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
-```
-
-An example of visualization:
-
-<div align="center">
-<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
-</div>

@@ -7,8 +7,11 @@ Visualization can give an intuitive interpretation of the performance of the model.
 - [Visualization](#visualization)
   - [How visualization is implemented](#how-visualization-is-implemented)
   - [What Visualization do in MMSelfsup](#what-visualization-do-in-mmselfsup)
-  - [Use different storage backends](#use-different-storage-backends)
+  - [Use Different Storage Backends](#use-different-storage-backends)
   - [Customize Visualization](#customize-visualization)
+  - [Visualize Datasets](#visualize-datasets)
+  - [Visualize t-SNE](#visualize-t-sne)
+  - [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction)
 
 <!-- /TOC -->
 
@@ -43,7 +46,7 @@ def after_train_iter(...):
 
 The function [`add_datasample()`](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/mmselfsup/visualization/selfsup_visualizer.py#L151) is implemented in [`SelfSupVisualizer`](mmselfsup.visualization.SelfSupVisualizer), and it is mainly used in [browse_dataset.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/browse_dataset.py) for browsing datasets. More tutorials are in [analysis_tools.md](analysis_tools.md).
 
-## Use different storage backends
+## Use Different Storage Backends
 
 If you want to use a different backend (Wandb, Tensorboard, or a custom backend with a remote window), just change the `vis_backends` in the config, as follows:
 
@@ -86,3 +89,114 @@ E.g.
 ## Customize Visualization
 
 The customization of the visualization is similar to other components. If you want to customize `Visualizer`, `VisBackend` or `VisualizationHook`, you can refer to the [Visualization Doc](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/visualization.md) in MMEngine.
+
+## Visualize Datasets
+
+`tools/misc/browse_dataset.py` helps the user browse a mmselfsup dataset (transformed images) visually, or save the images to a designated directory.
+
+```shell
+python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
+```
+
+An example:
+
+```shell
+python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py
+```
+
+An example of visualization:
+
+<div align="center">
+<img src="https://user-images.githubusercontent.com/36138628/199387454-219e6f6c-fbb7-43bb-b319-61d3e6266abc.png" width="600" />
+</div>
+
+- The left two pictures are images from the contrastive learning data pipeline.
+- The right one is a masked image.
+
+## Visualize t-SNE
+
+We provide an off-the-shelf tool to visualize the quality of image representations with t-SNE.
+
+```shell
+python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments]
+```
+
+Arguments:
+
+- `CONFIG_FILE`: config file for the pre-trained model.
+- `CKPT_PATH`: the path of the model's checkpoint.
+- `WORK_DIR`: the directory to save the results of visualization.
+- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
+
+An example:
+
+```shell
+python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
+```
+
+An example of visualization:
+
+<div align="center">
+<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
+</div>
+
+## Visualize Low-level Feature Reconstruction
+
+We provide reconstruction visualization for the following algorithms:
+
+- MAE
+- SimMIM
+- MaskFeat
+
+Users can run the command below to visualize the reconstruction.
+
+```shell
+python tools/analysis_tools/visualize_reconstruction.py ${CONFIG_FILE} \
+    --checkpoint ${CKPT_PATH} \
+    --img-path ${IMAGE_PATH} \
+    --out-file ${OUTPUT_PATH}
+```
+
+Arguments:
+
+- `CONFIG_FILE`: config file for the pre-trained model.
+- `CKPT_PATH`: the path of the model's checkpoint.
+- `IMAGE_PATH`: the input image path.
+- `OUTPUT_PATH`: the output image path, including 4 sub-images.
+- `[optional arguments]`: for optional arguments, you can refer to [visualize_reconstruction.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_reconstruction.py)
+
+An example:
+
+```shell
+python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/mae/mae_vit-huge-p16_8xb512-amp-coslr-1600e_in1k.py \
+    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth \
+    --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
+    --out-file test_mae.jpg \
+    --norm-pix
+
+
+# As for SimMIM, it generates the mask in the data pipeline, thus we use '--use-vis-pipeline' to apply 'vis_pipeline' defined in the config instead of the pipeline defined in the script.
+python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192.py \
+    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192_20220916-4ad216d3.pth \
+    --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
+    --out-file test_simmim.jpg \
+    --use-vis-pipeline
+```
+
+Results of MAE:
+
+<div align="center">
+<img src="https://user-images.githubusercontent.com/36138628/200465826-83f316ed-5a46-46a9-b665-784b5332d348.jpg" width="800" />
+</div>
+
+Results of SimMIM:
+
+<div align="center">
+<img src="https://user-images.githubusercontent.com/36138628/200466133-b77bc9af-224b-4810-863c-eed81ddd1afa.jpg" width="800" />
+</div>
+
+Results of MaskFeat:
+
+<div align="center">
+<img src="https://user-images.githubusercontent.com/36138628/200465876-7e7dcb6f-5e8d-4d80-b300-9e1847cb975f.jpg" width="800" />
+</div>

@@ -81,5 +81,6 @@ def inference_model(model: nn.Module,
 
     # forward the model
     with torch.no_grad():
-        results = model.test_step(data)
-    return results
+        inputs, data_samples = model.data_preprocessor(data, False)
+        features = model(inputs, data_samples, mode='tensor')
+    return features

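With this change, `inference_model` returns raw `mode='tensor'` features instead of a post-processed prediction, and callers feed that output into the algorithm's new `reconstruct` method. A sketch (with `model` built via `init_model` and, for MAE, per-patch `mean`/`std` computed from the patchified image as in the notebook above):

```python
from mmselfsup.apis import inference_model

# 'model', 'mean' and 'std' are assumed to exist as in the notebook sketch;
# SimMIM and MaskFeat need no extra kwargs.
features = inference_model(model, 'fox.jpg')  # raw mode='tensor' output
results = model.reconstruct(features, mean=mean, std=std)
```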
@@ -119,7 +119,7 @@ class BaseModel(_BaseModel):
             or ``dict of tensor for custom use.
         """
         if mode == 'tensor':
-            feats = self.extract_feat(inputs)
+            feats = self.extract_feat(inputs, data_samples=data_samples)
             return feats
         elif mode == 'loss':
             return self.loss(inputs, data_samples)

@@ -17,7 +17,9 @@ class MAE(BaseModel):
     <https://arxiv.org/abs/2111.06377>`_.
     """
 
-    def extract_feat(self, inputs: List[torch.Tensor],
+    def extract_feat(self,
+                     inputs: List[torch.Tensor],
+                     data_samples: Optional[List[SelfSupDataSample]] = None,
                      **kwarg) -> Tuple[torch.Tensor]:
         """The forward function to extract features from neck.
 
@@ -27,33 +29,33 @@ class MAE(BaseModel):
         Returns:
             Tuple[torch.Tensor]: Neck outputs.
         """
-        latent, _, ids_restore = self.backbone(inputs[0])
+        latent, mask, ids_restore = self.backbone(inputs[0])
         pred = self.neck(latent, ids_restore)
+        self.mask = mask
         return pred
 
-    def predict(self,
-                inputs: List[torch.Tensor],
-                data_samples: Optional[List[SelfSupDataSample]] = None,
-                **kwargs) -> SelfSupDataSample:
-        """The forward function in testing. It is mainly for image
-        reconstruction.
+    def reconstruct(self,
+                    features: torch.Tensor,
+                    data_samples: Optional[List[SelfSupDataSample]] = None,
+                    **kwargs) -> SelfSupDataSample:
+        """The function is for image reconstruction.
 
         Args:
-            inputs (List[torch.Tensor]): The input images.
+            features (torch.Tensor): The input images.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
 
         Returns:
             SelfSupDataSample: The prediction from model.
         """
-        latent, mask, ids_restore = self.backbone(inputs[0])
-        pred = self.neck(latent, ids_restore)
-
-        pred = self.head.unpatchify(pred)
+        mean = kwargs['mean']
+        std = kwargs['std']
+        features = features * std + mean
+
+        pred = self.head.unpatchify(features)
         pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
 
-        mask = mask.detach()
+        mask = self.mask.detach()
         mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
                                          3)  # (N, H*W, p*p*3)
         mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping

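Two details worth noting here: `extract_feat` now stashes the random mask on `self.mask`, and `reconstruct` pulls per-patch `mean`/`std` out of `**kwargs` to undo the `norm_pix` target normalization before unpatchifying. A minimal sketch mirroring the updated `test_mae` (with `alg` an initialized MAE and a preprocessed fake batch):

```python
feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')  # (2, 196, 768)
mean = feats.mean(dim=-1, keepdim=True)
std = (feats.var(dim=-1, keepdim=True) + 1.e-6)**.5
results = alg.reconstruct(feats, fake_data_samples, mean=mean, std=std)
# results.pred.value and results.mask.value: (2, 224, 224, 3)
```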
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
+from mmengine.structures import BaseDataElement
 
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
@@ -16,8 +17,10 @@ class MaskFeat(BaseModel):
     Pre-Training <https://arxiv.org/abs/2112.09133>`_.
     """
 
-    def extract_feat(self, inputs: List[torch.Tensor],
+    def extract_feat(self,
+                     inputs: List[torch.Tensor],
                      data_samples: List[SelfSupDataSample],
+                     compute_hog: bool = True,
                      **kwarg) -> Tuple[torch.Tensor]:
         """The forward function to extract features from neck.
 
@@ -25,15 +28,30 @@ class MaskFeat(BaseModel):
             inputs (List[torch.Tensor]): The input images and mask.
             data_samples (List[SelfSupDataSample]): All elements required
                 during the forward function.
+            compute_hog (bool): Whether to compute hog during extraction. If
+                True, the batch size of inputs need to be 1. Defaults to True.
 
         Returns:
             Tuple[torch.Tensor]: Neck outputs.
         """
         img = inputs[0]
-        mask = torch.stack(
+        self.mask = torch.stack(
             [data_sample.mask.value for data_sample in data_samples])
-        latent = self.backbone(img, mask)
-        return latent
+        latent = self.backbone(img, self.mask)
+        B, L, C = latent.shape
+        pred = self.neck([latent.view(B * L, C)])
+        pred = pred[0].view(B, L, -1)
+
+        # compute hog
+        if compute_hog:
+            assert img.size(0) == 1, 'Currently only support batch size 1.'
+            _ = self.target_generator(img)
+            hog_image = torch.from_numpy(
+                self.target_generator.generate_hog_image(
+                    self.target_generator.out)).unsqueeze(0).unsqueeze(0)
+            self.target = hog_image.expand(-1, 3, -1, -1)
+
+        return pred[:, 1:, :]  # remove cls token
 
     def loss(self, inputs: List[torch.Tensor],
              data_samples: List[SelfSupDataSample],
@@ -62,3 +80,60 @@ class MaskFeat(BaseModel):
         loss = self.head(pred, hog, mask)
         losses = dict(loss=loss)
         return losses
+
+    def reconstruct(self,
+                    features: List[torch.Tensor],
+                    data_samples: Optional[List[SelfSupDataSample]] = None,
+                    **kwargs) -> SelfSupDataSample:
+        """The function is for image reconstruction.
+
+        Args:
+            features (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            SelfSupDataSample: The prediction from model.
+        """
+
+        # recover to HOG description from feature embeddings
+        unfold_size = self.target_generator.unfold_size
+        tmp4 = features.unflatten(2,
+                                  (features.shape[2] // unfold_size**2,
+                                   unfold_size, unfold_size))  # 1,196,27,2,2
+        tmp3 = tmp4.unflatten(1, self.backbone.patch_resolution)
+
+        b, p1, p2, c_nbins, _, _ = tmp3.shape  # 1,14,14,27,2,2
+        tmp2 = tmp3.permute(0, 1, 2, 5, 3, 4).reshape(
+            (b, p1, p2 * unfold_size, c_nbins, unfold_size))
+        tmp1 = tmp2.permute(0, 1, 4, 2, 3).reshape(
+            (b, p1 * unfold_size, p2 * unfold_size, c_nbins))
+        tmp0 = tmp1.permute(0, 3, 1, 2)  # 1,27,28,28
+        hog_out = tmp0.unflatten(1,
+                                 (int(c_nbins // self.target_generator.nbins),
+                                  self.target_generator.nbins))  # 1,3,9,28,28
+
+        # generate prediction of HOG
+        hog_image = torch.from_numpy(
+            self.target_generator.generate_hog_image(hog_out))
+        hog_image = hog_image.unsqueeze(0).unsqueeze(0)
+        pred = torch.einsum('nchw->nhwc', hog_image).expand(-1, -1, -1,
+                                                            3).detach().cpu()
+
+        # transform patch mask to pixel mask
+        mask = self.mask
+        patch_dim_1 = int(self.backbone.patch_embed.init_input_size[0] //
+                          self.backbone.patch_resolution[0])
+        patch_dim_2 = int(self.backbone.patch_embed.init_input_size[1] //
+                          self.backbone.patch_resolution[1])
+        mask = mask.repeat_interleave(
+            patch_dim_1, dim=1).repeat_interleave(
+                patch_dim_2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3)
+        # 1 is removing, 0 is keeping
+        mask = mask.detach().cpu()
+
+        results = SelfSupDataSample()
+        results.mask = BaseDataElement(**dict(value=mask))
+        results.pred = BaseDataElement(**dict(value=pred))
+
+        return results

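MaskFeat reconstructs a rendered HOG image rather than raw pixels, which is why `extract_feat` is restricted to batch size 1 when `compute_hog=True`. A sketch of the round trip, mirroring the updated `test_maskfeat` (`alg` is an initialized MaskFeat):

```python
feats = alg.extract_feat(fake_batch_inputs, fake_data_samples)  # (1, 196, 108)
results = alg.reconstruct(feats, fake_data_samples)
# results.pred.value: rendered HOG image, (1, 224, 224, 3)
# results.mask.value: pixel-level mask,  (1, 224, 224, 3)
```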
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import torch
+from mmengine.structures import BaseDataElement
 
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
@@ -33,6 +34,7 @@ class SimMIM(BaseModel):
             [data_sample.mask.value for data_sample in data_samples])
         img_latent = self.backbone(inputs[0], mask)
         feat = self.neck(img_latent[0])
+        self.mask = mask
         return feat
 
     def loss(self, inputs: List[torch.Tensor],
@@ -58,3 +60,37 @@ class SimMIM(BaseModel):
         losses = dict(loss=loss)
 
         return losses
+
+    def reconstruct(self,
+                    features: torch.Tensor,
+                    data_samples: Optional[List[SelfSupDataSample]] = None,
+                    **kwargs) -> SelfSupDataSample:
+        """The function is for image reconstruction.
+
+        Args:
+            features (torch.Tensor): The input images.
+            data_samples (List[SelfSupDataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            SelfSupDataSample: The prediction from model.
+        """
+        pred = torch.einsum('nchw->nhwc', features).detach().cpu()
+
+        # transform patch mask to pixel mask
+        mask = self.mask.detach()
+        p1 = int(self.backbone.patch_embed.init_input_size[0] //
+                 self.backbone.patch_resolution[0])
+        p2 = int(self.backbone.patch_embed.init_input_size[1] //
+                 self.backbone.patch_resolution[1])
+        mask = mask.repeat_interleave(
+            p1, dim=1).repeat_interleave(
+                p2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3)  # (N, H, W, 3)
+        # 1 is removing, 0 is keeping
+        mask = mask.detach().cpu()
+
+        results = SelfSupDataSample()
+        results.mask = BaseDataElement(**dict(value=mask))
+        results.pred = BaseDataElement(**dict(value=pred))
+
+        return results

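SimMIM's decoder already predicts full-resolution pixels, so its `reconstruct` only reorders channels and upsamples the cached patch mask. A usage sketch, as in the updated `test_simmim` (`model` is an initialized SimMIM):

```python
feats = model.extract_feat(fake_inputs, fake_data_samples)  # (N, 3, 192, 192)
results = model.reconstruct(feats, fake_data_samples)
# results.pred.value and results.mask.value: (N, 192, 192, 3)
```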
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 
+import cv2
+import numpy as np
 import torch
 import torch.nn.functional as F
 from mmengine.model import BaseModule
@@ -13,10 +15,10 @@ class HOGGenerator(BaseModule):
     """Generate HOG feature for images.
 
     This module is used in MaskFeat to generate HOG feature. The code is
-    modified from this `file
+    modified from file `slowfast/models/operators.py
     <https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
-    Here is the link `HOG wikipedia
-    <https://en.m.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
+    Here is the link of `HOG wikipedia
+    <https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
 
     Args:
         nbins (int): Number of bin. Defaults to 9.
@@ -61,12 +63,12 @@ class HOGGenerator(BaseModule):
     def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
         """Reshape HOG Features for output."""
         hog_feat = hog_feat.flatten(1, 2)
-        unfold_size = hog_feat.shape[-1] // 14
-        hog_feat = (
-            hog_feat.permute(0, 2, 3,
-                             1).unfold(1, unfold_size, unfold_size).unfold(
-                                 2, unfold_size,
-                                 unfold_size).flatten(1, 2).flatten(2))
+        self.unfold_size = hog_feat.shape[-1] // 14
+        hog_feat = hog_feat.permute(0, 2, 3, 1)
+        hog_feat = hog_feat.unfold(1, self.unfold_size,
+                                   self.unfold_size).unfold(
+                                       2, self.unfold_size, self.unfold_size)
+        hog_feat = hog_feat.flatten(1, 2).flatten(2)
         return hog_feat
 
     @torch.no_grad()
@@ -80,6 +82,7 @@ class HOGGenerator(BaseModule):
             torch.Tensor: Hog features.
         """
         # input is RGB image with shape [B 3 H W]
+        self.h, self.w = x.size(-2), x.size(-1)
         x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
         gx_rgb = F.conv2d(
             x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
@@ -112,6 +115,38 @@ class HOGGenerator(BaseModule):
         out = out.unfold(4, self.pool, self.pool)
         out = out.sum(dim=[-1, -2])
 
-        out = F.normalize(out, p=2, dim=2)
+        self.out = F.normalize(out, p=2, dim=2)
 
-        return self._reshape(out)
+        return self._reshape(self.out)
+
+    def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray:
+        """Generate HOG image according to HOG features."""
+        assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \
+            'Check the input batch size and the channel number, only support'\
+            ' "batch_size = 1".'
+        hog_image = np.zeros([self.h, self.w])
+        cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu())
+        cell_width = self.pool / 2
+        max_mag = np.array(cell_gradient).max()
+        angle_gap = 360 / self.nbins
+
+        for x in range(cell_gradient.shape[1]):
+            for y in range(cell_gradient.shape[2]):
+                cell_grad = cell_gradient[:, x, y]
+                cell_grad /= max_mag
+                angle = 0
+                for magnitude in cell_grad:
+                    angle_radian = math.radians(angle)
+                    x1 = int(x * self.pool +
+                             magnitude * cell_width * math.cos(angle_radian))
+                    y1 = int(y * self.pool +
+                             magnitude * cell_width * math.sin(angle_radian))
+                    x2 = int(x * self.pool -
+                             magnitude * cell_width * math.cos(angle_radian))
+                    y2 = int(y * self.pool -
+                             magnitude * cell_width * math.sin(angle_radian))
+                    magnitude = 0 if magnitude < 0 else magnitude
+                    cv2.line(hog_image, (y1, x1), (y2, x2),
+                             int(255 * math.sqrt(magnitude)))
+                    angle += angle_gap
+        return hog_image

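`generate_hog_image` renders the cached histogram (`self.out`) back into a grayscale HOG image, one sample at a time. A self-contained sketch matching the updated unit test:

```python
import torch

from mmselfsup.models.target_generators import HOGGenerator

hog_generator = HOGGenerator(nbins=9, pool=8, gaussian_window=16)
feats = hog_generator(torch.randn(1, 3, 224, 224))  # (1, 196, 108)

# self.out caches the L2-normalized histograms, shape (1, 3, 9, 28, 28).
hog_image = hog_generator.generate_hog_image(hog_generator.out[0].unsqueeze(0))
assert hog_image.shape == (224, 224)  # numpy array
```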
@@ -27,7 +27,7 @@ class ExampleModel(BaseModel):
         super(ExampleModel, self).__init__(backbone=backbone)
         self.layer = nn.Linear(1, 1)
 
-    def predict(self,
+    def extract_feat(self,
                 inputs: List[torch.Tensor],
                 data_samples: Optional[List[SelfSupDataSample]] = None,
                 **kwargs) -> SelfSupDataSample:

@@ -48,9 +48,14 @@ def test_mae():
     fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
     assert isinstance(fake_outputs['loss'].item(), float)
 
+    # test extraction
     fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
     assert list(fake_feats.shape) == [2, 196, 768]
 
-    results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
+    # test reconstruct
+    mean = fake_feats.mean(dim=-1, keepdim=True)
+    std = (fake_feats.var(dim=-1, keepdim=True) + 1.e-6)**.5
+    results = alg.reconstruct(
+        fake_feats, fake_data_samples, mean=mean, std=std)
     assert list(results.mask.value.shape) == [2, 224, 224, 3]
     assert list(results.pred.value.shape) == [2, 224, 224, 3]

@@ -5,6 +5,7 @@ import platform
 import pytest
 import torch
 from mmengine.structures import InstanceData
+from mmengine.utils import digit_version
 
 from mmselfsup.models.algorithms.maskfeat import MaskFeat
 from mmselfsup.structures import SelfSupDataSample
@@ -22,6 +23,9 @@ target_generator = dict(
     type='HOGGenerator', nbins=9, pool=8, gaussian_window=16)
 
 
+@pytest.mark.skipif(
+    digit_version(torch.__version__) < digit_version('1.7.0'),
+    reason='torch version')
 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
 def test_maskfeat():
     data_preprocessor = {
@@ -42,13 +46,19 @@ def test_maskfeat():
     fake_mask = InstanceData(value=torch.rand((14, 14)).bool())
     fake_data_sample.mask = fake_mask
     fake_data = {
-        'inputs': [torch.randn((2, 3, 224, 224))],
-        'data_sample': [fake_data_sample for _ in range(2)]
+        'inputs': [torch.randn((1, 3, 224, 224))],
+        'data_sample': [fake_data_sample for _ in range(1)]
     }
 
     fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
     fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
     assert isinstance(fake_outputs['loss'].item(), float)
 
+    # test extraction
     fake_feats = alg.extract_feat(fake_batch_inputs, fake_data_samples)
-    assert list(fake_feats.shape) == [2, 197, 768]
+    assert list(fake_feats.shape) == [1, 196, 108]
+
+    # test reconstruction
+    results = alg.reconstruct(fake_feats, fake_data_samples)
+    assert list(results.mask.value.shape) == [1, 224, 224, 3]
+    assert list(results.pred.value.shape) == [1, 224, 224, 3]

@@ -50,5 +50,10 @@ def test_simmim():
 
     # test extract_feat
     fake_inputs, fake_data_samples = model.data_preprocessor(fake_data)
-    fake_feat = model.extract_feat(fake_inputs, fake_data_samples)
-    assert list(fake_feat.shape) == [2, 3, 192, 192]
+    fake_feats = model.extract_feat(fake_inputs, fake_data_samples)
+    assert list(fake_feats.shape) == [2, 3, 192, 192]
+
+    # test reconstruct
+    results = model.reconstruct(fake_feats, fake_data_samples)
+    assert list(results.mask.value.shape) == [2, 192, 192, 3]
+    assert list(results.pred.value.shape) == [2, 192, 192, 3]

@ -1,4 +1,5 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmselfsup.models.target_generators import HOGGenerator
|
from mmselfsup.models.target_generators import HOGGenerator
|
||||||
|
@@ -10,3 +11,10 @@ def test_hog_generator():
     fake_input = torch.randn((2, 3, 224, 224))
     fake_output = hog_generator(fake_input)
     assert list(fake_output.shape) == [2, 196, 108]
+
+    fake_hog_out = hog_generator.out[0].unsqueeze(0)
+    fake_hog_img = hog_generator.generate_hog_image(fake_hog_out)
+    assert fake_hog_img.shape == (224, 224)
+
+    with pytest.raises(AssertionError):
+        fake_hog_img = hog_generator.generate_hog_image(hog_generator.out[0])
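This test also pins down the generate_hog_image contract: the method expects the batched HOG map, so a single sample must be re-batched with unsqueeze(0), and passing the raw hog_generator.out[0] raises an AssertionError. A short usage sketch, following the test above:

import torch
from mmselfsup.models.target_generators import HOGGenerator

hog_generator = HOGGenerator(nbins=9, pool=8, gaussian_window=16)
_ = hog_generator(torch.randn(2, 3, 224, 224))   # populates hog_generator.out

# re-batch one sample before rendering it as a single-channel HOG image
hog_img = hog_generator.generate_hog_image(hog_generator.out[0].unsqueeze(0))
assert hog_img.shape == (224, 224)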
@@ -1,108 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# Modified from https://colab.research.google.com/github/facebookresearch/mae
-# /blob/main/demo/mae_visualize.ipynb
-from argparse import ArgumentParser
-from typing import Tuple
-
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from mmengine.dataset import Compose, default_collate
-
-from mmselfsup.apis import inference_model, init_model
-from mmselfsup.registry import MODELS
-from mmselfsup.utils import register_all_modules
-
-imagenet_mean = np.array([0.485, 0.456, 0.406])
-imagenet_std = np.array([0.229, 0.224, 0.225])
-
-
-def show_image(image: torch.Tensor, title: str = '') -> None:
-    # image is [H, W, 3]
-    assert image.shape[2] == 3
-    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
-                       255).int()
-    plt.imshow(image)
-    plt.title(title, fontsize=16)
-    plt.axis('off')
-    return
-
-
-def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
-                im_paste: torch.Tensor, out_file: str) -> None:
-    # make the plt figure larger
-    plt.rcParams['figure.figsize'] = [24, 6]
-
-    plt.subplot(1, 4, 1)
-    show_image(x, 'original')
-
-    plt.subplot(1, 4, 2)
-    show_image(im_masked, 'masked')
-
-    plt.subplot(1, 4, 3)
-    show_image(y, 'reconstruction')
-
-    plt.subplot(1, 4, 4)
-    show_image(im_paste, 'reconstruction + visible')
-
-    plt.savefig(out_file)
-
-
-def post_process(
-    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    x = torch.einsum('nchw->nhwc', x.cpu())
-    # masked image
-    im_masked = x * (1 - mask)
-    # MAE reconstruction pasted with visible patches
-    im_paste = x * (1 - mask) + y * mask
-    return x[0], im_masked[0], y[0], im_paste[0]
-
-
-def main():
-    parser = ArgumentParser()
-    parser.add_argument('img_path', help='Image file path')
-    parser.add_argument('config', help='MAE Config file')
-    parser.add_argument('checkpoint', help='Checkpoint file')
-    parser.add_argument('out_file', help='The output image file path')
-    parser.add_argument(
-        '--device', default='cuda:0', help='Device used for inference')
-    args = parser.parse_args()
-
-    register_all_modules()
-
-    # build the model from a config file and a checkpoint file
-    model = init_model(args.config, args.checkpoint, device=args.device)
-    print('Model loaded.')
-
-    # make random mask reproducible (comment out to make it change)
-    torch.manual_seed(2)
-    print('MAE with pixel reconstruction:')
-
-    model.cfg.test_dataloader = dict(
-        dataset=dict(pipeline=[
-            dict(
-                type='LoadImageFromFile',
-                file_client_args=dict(backend='disk')),
-            dict(type='Resize', scale=(224, 224)),
-            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
-        ]))
-
-    results = inference_model(model, args.img_path)
-
-    cfg = model.cfg
-    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
-    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
-    data = dict(img_path=args.img_path)
-    data = test_pipeline(data)
-    data = default_collate([data])
-    img, _ = data_preprocessor(data, False)
-
-    x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
-                                             results.mask.value)
-    save_images(x, im_masked, y, im_paste, args.out_file)
-
-
-if __name__ == '__main__':
-    main()
@@ -0,0 +1,178 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://colab.research.google.com/github/facebookresearch/mae
+# /blob/main/demo/mae_visualize.ipynb
+import random
+from argparse import ArgumentParser
+from typing import Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from mmengine.dataset import Compose, default_collate
+
+from mmselfsup.apis import inference_model, init_model
+from mmselfsup.utils import register_all_modules
+
+imagenet_mean = np.array([0.485, 0.456, 0.406])
+imagenet_std = np.array([0.229, 0.224, 0.225])
+
+
+def show_image(img: torch.Tensor, title: str = '') -> None:
+    # image is [H, W, 3]
+    assert img.shape[2] == 3
+
+    plt.imshow(img)
+    plt.title(title, fontsize=16)
+    plt.axis('off')
+    return
+
+
+def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,
+                pred_img: torch.Tensor, img_paste: torch.Tensor,
+                out_file: str) -> None:
+    # make the plt figure larger
+    plt.rcParams['figure.figsize'] = [24, 6]
+
+    plt.subplot(1, 4, 1)
+    show_image(original_img, 'original')
+
+    plt.subplot(1, 4, 2)
+    show_image(img_masked, 'masked')
+
+    plt.subplot(1, 4, 3)
+    show_image(pred_img, 'reconstruction')
+
+    plt.subplot(1, 4, 4)
+    show_image(img_paste, 'reconstruction + visible')
+
+    plt.savefig(out_file)
+    print(f'Images are saved to {out_file}')
+
+
+def recover_norm(img: torch.Tensor,
+                 mean: np.ndarray = imagenet_mean,
+                 std: np.ndarray = imagenet_std):
+    if mean is not None and std is not None:
+        img = torch.clip((img * std + mean) * 255, 0, 255).int()
+    return img
+
+
+def post_process(
+    original_img: torch.Tensor,
+    pred_img: torch.Tensor,
+    mask: torch.Tensor,
+    mean: np.ndarray = imagenet_mean,
+    std: np.ndarray = imagenet_std
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    # channel conversion
+    original_img = torch.einsum('nchw->nhwc', original_img.cpu())
+    # masked image
+    img_masked = original_img * (1 - mask)
+    # reconstructed image pasted with visible patches
+    img_paste = original_img * (1 - mask) + pred_img * mask
+
+    # multiply std and add mean to each image
+    original_img = recover_norm(original_img[0])
+    img_masked = recover_norm(img_masked[0])
+
+    pred_img = recover_norm(pred_img[0])
+    img_paste = recover_norm(img_paste[0])
+
+    return original_img, img_masked, pred_img, img_paste
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('config', help='Model config file')
+    parser.add_argument('--checkpoint', help='Checkpoint file')
+    parser.add_argument('--img-path', help='Image file path')
+    parser.add_argument('--out-file', help='The output image file path')
+    parser.add_argument(
+        '--use-vis-pipeline',
+        action='store_true',
+        help='Use the vis_pipeline defined in the config. Some algorithms, '
+        'such as SimMIM and MaskFeat, generate the mask in the data '
+        'pipeline, so the visualization process applies the vis_pipeline '
+        'from the config to obtain the mask.')
+    parser.add_argument(
+        '--norm-pix',
+        action='store_true',
+        help='MAE uses `norm_pix_loss` for optimization in pre-training, so '
+        'the visualization process also needs to compute the mean and std '
+        'of each patch embedding while reconstructing the original images.')
+    parser.add_argument(
+        '--target-generator',
+        action='store_true',
+        help='Some algorithms use a target_generator for optimization in '
+        'pre-training, such as MaskFeat; turn this on to visualize the '
+        'target instead of the RGB image.')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--seed',
+        type=int,
+        default=0,
+        help='The random seed for visualization')
+    args = parser.parse_args()
+
+    register_all_modules()
+
+    # build the model from a config file and a checkpoint file
+    model = init_model(args.config, args.checkpoint, device=args.device)
+    print('Model loaded.')
+
+    # make random mask reproducible (comment out to make it change)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    print('Reconstruction visualization.')
+
+    if args.use_vis_pipeline:
+        model.cfg.test_dataloader = dict(
+            dataset=dict(pipeline=model.cfg.vis_pipeline))
+    else:
+        model.cfg.test_dataloader = dict(
+            dataset=dict(pipeline=[
+                dict(
+                    type='LoadImageFromFile',
+                    file_client_args=dict(backend='disk')),
+                dict(type='Resize', scale=(224, 224), backend='pillow'),
+                dict(type='PackSelfSupInputs', meta_keys=['img_path'])
+            ]))
+
+    # get original image
+    vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
+    data = dict(img_path=args.img_path)
+    data = vis_pipeline(data)
+    data = default_collate([data])
+    img, _ = model.data_preprocessor(data, False)
+
+    if args.norm_pix:
+        # for MAE reconstruction
+        img_embedding = model.head.patchify(img[0])
+        # normalize the target image
+        mean = img_embedding.mean(dim=-1, keepdim=True)
+        std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5
+    else:
+        mean = imagenet_mean
+        std = imagenet_std
+
+    # get reconstruction image
+    features = inference_model(model, args.img_path)
+    results = model.reconstruct(features, mean=mean, std=std)
+
+    original_target = model.target if args.target_generator else img[0]
+
+    original_img, img_masked, pred_img, img_paste = post_process(
+        original_target,
+        results.pred.value,
+        results.mask.value,
+        mean=mean,
+        std=std)
+
+    save_images(original_img, img_masked, pred_img, img_paste, args.out_file)
+
+
+if __name__ == '__main__':
+    main()
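Taken together, the flags map onto the supported algorithms as follows: a model pre-trained with norm_pix_loss (MAE, per the help text) needs --norm-pix; SimMIM and MaskFeat need --use-vis-pipeline, since their masks come from the vis_pipeline entries added to the configs above; and MaskFeat can additionally pass --target-generator to visualize the HOG target instead of the RGB image. The per-patch statistics that --norm-pix recomputes are the ones the reconstruction path later inverts; a minimal sketch of that round trip (the [1, 196, 768] embedding shape assumes a ViT-B patchify of a 224×224 image and is not fixed by this diff; recover_norm additionally rescales to the 0-255 range):

import torch

# Stand-in for model.head.patchify(img[0]); shape assumes ViT-B on 224x224 input.
img_embedding = torch.randn(1, 196, 768)
mean = img_embedding.mean(dim=-1, keepdim=True)
std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5

normed = (img_embedding - mean) / std   # what a norm_pix_loss target looks like
recovered = normed * std + mean         # the inverse applied at visualization time
assert torch.allclose(recovered, img_embedding, atol=1e-4)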