[Feature] Support pixel reconstruction visualization (#570)
* refactor reconstruction visualization * support simmim visualization * fix reconstruction bug of MAE * support visualization of MaskFeat * refactor mae visualization demo * add unit test * fix lint and ut * update * add docs * set random seed * update * update docstring * add torch version check * update * rename * update version * update * fix lint * add docstring * update docs
parent d73c953804
commit 73cd764b5f
@@ -37,3 +37,19 @@ train_dataloader = dict(
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline))

# for visualization
vis_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(224, 224), backend='pillow'),
    dict(
        type='BEiTMaskGenerator',
        input_size=14,
        num_masking_patches=78,
        min_num_patches=15,
    ),
    dict(
        type='PackSelfSupInputs',
        algorithm_keys=['mask'],
        meta_keys=['img_path'])
]
@@ -34,3 +34,19 @@ train_dataloader = dict(
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline))

# for visualization
vis_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(type='Resize', scale=(192, 192), backend='pillow'),
    dict(
        type='SimMIMMaskGenerator',
        input_size=192,
        mask_patch_size=32,
        model_patch_size=4,
        mask_ratio=0.6),
    dict(
        type='PackSelfSupInputs',
        algorithm_keys=['mask'],
        meta_keys=['img_path'])
]
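Both configs above expose a dedicated `vis_pipeline` so the visualization tool can regenerate the mask exactly as the training pipeline would. A minimal sketch of how the visualization script later in this diff consumes it when `--use-vis-pipeline` is passed (assumes `model` was built with `mmselfsup.apis.init_model` from one of these configs; the image path is illustrative):

```python
from mmengine.dataset import Compose, default_collate

# Swap the test pipeline for the vis_pipeline defined in the config,
# so the generated mask travels with the packed inputs.
model.cfg.test_dataloader = dict(
    dataset=dict(pipeline=model.cfg.vis_pipeline))

vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
data = default_collate([vis_pipeline(dict(img_path='demo.jpg'))])
img, _ = model.data_preprocessor(data, False)  # batched, normalized image
```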
@@ -6,8 +6,6 @@
   "source": [
    "Copyright (c) OpenMMLab. All rights reserved.\n",
    "\n",
    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
    "\n",
    "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
    "\n",
    "## Masked Autoencoders: Visualization Demo\n",
@@ -36,7 +34,8 @@
    " print('Running in Colab.')\n",
    " !pip3 install openmim\n",
    " !pip install -U openmim\n",
    " !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
    " !mim install mmengine\n",
    " !mim install 'mmcv>=2.0.0rc1'\n",
    "\n",
    " !git clone https://github.com/open-mmlab/mmselfsup.git\n",
    " %cd mmselfsup/\n",
@@ -51,18 +50,19 @@
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from argparse import ArgumentParser\n",
    "from typing import Tuple, Optional\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from mmengine.dataset import Compose, default_collate\n",
    "\n",
    "from mmselfsup.apis import inference_model\n",
    "from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
    "from mmselfsup.registry import MODELS\n",
    "from mmselfsup.apis import inference_model, init_model\n",
    "from mmselfsup.utils import register_all_modules"
   ]
  },
@@ -75,7 +75,7 @@
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -84,49 +84,81 @@
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "def show_image(image, title=''):\n",
    "\n",
    "def show_image(img: torch.Tensor, title: str = '') -> None:\n",
    "    # image is [H, W, 3]\n",
    "    assert image.shape[2] == 3\n",
    "    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
    "    plt.imshow(image)\n",
    "    assert img.shape[2] == 3\n",
    "\n",
    "    plt.imshow(img)\n",
    "    plt.title(title, fontsize=16)\n",
    "    plt.axis('off')\n",
    "    return\n",
    "\n",
    "\n",
    "def show_images(x, im_masked, y, im_paste):\n",
    "def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,\n",
    "                pred_img: torch.Tensor, img_paste: torch.Tensor,\n",
    "                out_file: Optional[str] = None) -> None:\n",
    "    # make the plt figure larger\n",
    "    plt.rcParams['figure.figsize'] = [24, 6]\n",
    "\n",
    "    plt.subplot(1, 4, 1)\n",
    "    show_image(x, \"original\")\n",
    "    show_image(original_img, 'original')\n",
    "\n",
    "    plt.subplot(1, 4, 2)\n",
    "    show_image(im_masked, \"masked\")\n",
    "    show_image(img_masked, 'masked')\n",
    "\n",
    "    plt.subplot(1, 4, 3)\n",
    "    show_image(y, \"reconstruction\")\n",
    "    show_image(pred_img, 'reconstruction')\n",
    "\n",
    "    plt.subplot(1, 4, 4)\n",
    "    show_image(im_paste, \"reconstruction + visible\")\n",
    "    show_image(img_paste, 'reconstruction + visible')\n",
    "\n",
    "    plt.show()\n",
    "    if out_file is None:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.savefig(out_file)\n",
    "        print(f'Images are saved to {out_file}')\n",
    "\n",
    "\n",
    "def post_process(x, y, mask):\n",
    "    x = torch.einsum('nchw->nhwc', x.cpu())\n",
    "def recover_norm(img: torch.Tensor,\n",
    "                 mean: np.ndarray = imagenet_mean,\n",
    "                 std: np.ndarray = imagenet_std):\n",
    "    if mean is not None and std is not None:\n",
    "        img = torch.clip((img * std + mean) * 255, 0, 255).int()\n",
    "    return img\n",
    "\n",
    "\n",
    "def post_process(\n",
    "    original_img: torch.Tensor,\n",
    "    pred_img: torch.Tensor,\n",
    "    mask: torch.Tensor,\n",
    "    mean: np.ndarray = imagenet_mean,\n",
    "    std: np.ndarray = imagenet_std\n",
    ") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    # channel conversion\n",
    "    original_img = torch.einsum('nchw->nhwc', original_img.cpu())\n",
    "    # masked image\n",
    "    im_masked = x * (1 - mask)\n",
    "    # MAE reconstruction pasted with visible patches\n",
    "    im_paste = x * (1 - mask) + y * mask\n",
    "    return x[0], im_masked[0], y[0], im_paste[0]"
    "    img_masked = original_img * (1 - mask)\n",
    "    # reconstructed image pasted with visible patches\n",
    "    img_paste = original_img * (1 - mask) + pred_img * mask\n",
    "\n",
    "    # multiply std and add mean to each image\n",
    "    original_img = recover_norm(original_img[0])\n",
    "    img_masked = recover_norm(img_masked[0])\n",
    "\n",
    "    pred_img = recover_norm(pred_img[0])\n",
    "    img_paste = recover_norm(img_paste[0])\n",
    "\n",
    "    return original_img, img_masked, pred_img, img_paste\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare config file"
    "### Load a pre-trained MAE model\n",
    "\n",
    "This is an MAE model trained with config 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k.py'.\n"
   ]
  },
  {
@@ -138,55 +170,20 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
      "--2022-11-08 11:00:50--  https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1355429265 (1.3G) [application/octet-stream]\n",
      "Saving to: 'mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth'\n",
      "\n",
      "e_in1k_20220825-cc7  99%[==================> ]   1.26G   913KB/s    eta 0s     "
     ]
    }
   ],
   "source": [
    "%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
    "model = dict(\n",
    "    type='MAE',\n",
    "    data_preprocessor=dict(\n",
    "        mean=[123.675, 116.28, 103.53],\n",
    "        std=[58.395, 57.12, 57.375],\n",
    "        bgr_to_rgb=True),\n",
    "    backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
    "    neck=dict(\n",
    "        type='MAEPretrainDecoder',\n",
    "        patch_size=16,\n",
    "        in_chans=3,\n",
    "        embed_dim=1024,\n",
    "        decoder_embed_dim=512,\n",
    "        decoder_depth=8,\n",
    "        decoder_num_heads=16,\n",
    "        mlp_ratio=4.,\n",
    "    ),\n",
    "    head=dict(\n",
    "        type='MAEPretrainHead',\n",
    "        norm_pix=True,\n",
    "        patch_size=16,\n",
    "        loss=dict(type='MAEReconstructionLoss')),\n",
    "    init_cfg=[\n",
    "        dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
    "        dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
    "    ])\n",
    "\n",
    "file_client_args = dict(backend='disk')\n",
    "\n",
    "# dataset summary\n",
    "test_dataloader = dict(\n",
    "    dataset=dict(pipeline=[\n",
    "        dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
    "        dict(type='Resize', scale=(224, 224)),\n",
    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
    "    ]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load a pre-trained MAE model"
    "# download checkpoint if not exist\n",
    "!wget -nc https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth"
   ]
  },
  {
@@ -198,46 +195,21 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-09-03 00:34:55--  https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1318299501 (1.2G) [application/octet-stream]\n",
      "Saving to: 'mae_visualize_vit_large.pth'\n",
      "local loads checkpoint from path: mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\n",
      "The model and loaded state dict do not match exactly\n",
      "\n",
      "mae_visualize_vit_l 100%[===================>]   1.23G  3.22MB/s    in 6m 4s  \n",
      "unexpected key in source state_dict: data_preprocessor.mean, data_preprocessor.std\n",
      "\n",
      "2022-09-03 00:40:59 (3.46 MB/s) - 'mae_visualize_vit_large.pth' saved [1318299501/1318299501])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
    "\n",
    "# download checkpoint if not exist\n",
    "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
    "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "local loads checkpoint from path: mae_visualize_vit_large.pth\n",
      "Model loaded.\n"
     ]
    }
   ],
   "source": [
    "from mmselfsup.apis import init_model\n",
    "ckpt_path = \"mae_visualize_vit_large.pth\"\n",
    "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
    "ckpt_path = \"mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth\"\n",
    "model = init_model(\n",
    "    '../configs/selfsup/mae/mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py',\n",
    "    ckpt_path,\n",
    "    device='cpu')\n",
    "print('Model loaded.')"
   ]
  },
@@ -250,16 +222,16 @@
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f5029d19950>"
       "<torch._C.Generator at 0x7fb2ccfbac90>"
      ]
     },
     "execution_count": 7,
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -272,23 +244,23 @@
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-09-03 00:41:01--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... connected.\n",
      "--2022-11-08 11:21:14--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 47.102.71.233\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|47.102.71.233|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 60133 (59K) [image/jpeg]\n",
      "Saving to: 'fox.jpg'\n",
      "\n",
      "fox.jpg             100%[===================>]  58.72K  --.-KB/s    in 0.06s   \n",
      "fox.jpg             100%[===================>]  58.72K  --.-KB/s    in 0.05s   \n",
      "\n",
      "2022-09-03 00:41:01 (962 KB/s) - 'fox.jpg' saved [60133/60133])\n",
      "2022-11-08 11:21:15 (1.08 MB/s) - 'fox.jpg' saved [60133/60133])\n",
      "\n"
     ]
    }
@@ -299,22 +271,34 @@
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = 'fox.jpg'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = model.cfg\n",
    "test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
    "data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
    "model.cfg.test_dataloader = dict(\n",
    "    dataset=dict(pipeline=[\n",
    "        dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),\n",
    "        dict(type='Resize', scale=(224, 224), backend='pillow'),\n",
    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
    "    ]))\n",
    "\n",
    "vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)"
   ]
  },
  {
@@ -324,36 +308,62 @@
   "outputs": [],
   "source": [
    "data = dict(img_path=img_path)\n",
    "data = test_pipeline(data)\n",
    "data = vis_pipeline(data)\n",
    "data = default_collate([data])\n",
    "img, _ = data_preprocessor(data, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = [5, 5]\n",
    "show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
    "img, _ = model.data_preprocessor(data, False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run MAE on the image"
    "### Reconstruction pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = inference_model(model, img_path)\n",
    "x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
    "# for MAE reconstruction\n",
    "img_embedding = model.head.patchify(img[0])\n",
    "# normalize the target image\n",
    "mean = img_embedding.mean(dim=-1, keepdim=True)\n",
    "std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get reconstruction image\n",
    "features = inference_model(model, img_path)\n",
    "results = model.reconstruct(features, mean=mean, std=std)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "original_target = img[0]\n",
    "original_img, img_masked, pred_img, img_paste = post_process(\n",
    "    original_target,\n",
    "    results.pred.value,\n",
    "    results.mask.value,\n",
    "    mean=mean,\n",
    "    std=std)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show the image"
   ]
  },
  {
@@ -362,21 +372,13 @@
   "metadata": {},
   "outputs": [],
   "source": [
    "print('MAE with pixel reconstruction:')\n",
    "show_images(x, im_masked, y, im_paste)"
    "save_images(original_img, img_masked, pred_img, img_paste)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "display_name": "Python 3.7.0 ('openmmlab')",
   "language": "python",
   "name": "python3"
  },
@@ -390,11 +392,11 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
   "version": "3.7.0"
  },
  "vscode": {
   "interpreter": {
    "hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
    "hash": "5909b3386efe3692f76356628babf720cfd47771f5d858315790cc041eb41361"
   }
  }
 },
@@ -7,8 +7,6 @@
- [Publish a model](#publish-a-model)
- [Reproducibility](#reproducibility)
- [Log Analysis](#log-analysis)
- [Visualize Datasets](#visualize-datasets)
- [Use t-SNE](#use-t-sne)

## Count number of parameters
@@ -92,53 +90,3 @@ Examples:
time std over epochs is 0.0028
average iter time: 1.1959 s/iter
```

## Visualize Datasets

`tools/misc/browse_dataset.py` helps the user to browse a mmselfsup dataset (transformed images) visually, or save the image to a designated directory.

```shell
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
```

An example:

```shell
python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py
```

An example of visualization:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199387454-219e6f6c-fbb7-43bb-b319-61d3e6266abc.png" width="600" />
</div>

- The left two pictures are images from contrastive learning data pipeline.
- The right one is a masked image.

## Use t-SNE

We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE.

```shell
python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments]
```

Arguments:

- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of model's checkpoint.
- `WORK_DIR`: the directory to save the results of visualization.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)

An example:

```shell
python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
```

An example of visualization:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
</div>
@@ -7,8 +7,11 @@ Visualization can give an intuitive interpretation of the performance of the mod
- [Visualization](#visualization)
  - [How visualization is implemented](#how-visualization-is-implemented)
  - [What Visualization do in MMSelfsup](#what-visualization-do-in-mmselfsup)
  - [Use different storage backends](#use-different-storage-backends)
  - [Use Different Storage Backends](#use-different-storage-backends)
  - [Customize Visualization](#customize-visualization)
  - [Visualize Datasets](#visualize-datasets)
  - [Visualize t-SNE](#visualize-t-sne)
  - [Visualize Low-level Feature Reconstruction](#visualize-low-level-feature-reconstruction)

<!-- /TOC -->
@@ -43,7 +46,7 @@ def after_train_iter(...):

The function [`add_datasample()`](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/mmselfsup/visualization/selfsup_visualizer.py#L151) is implemented in [`SelfSupVisualizer`](mmselfsup.visualization.SelfSupVisualizer), and it is mainly used in [browse_dataset.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/browse_dataset.py) for browsing datasets. More tutorials are in [analysis_tools.md](analysis_tools.md).

## Use different storage backends
## Use Different Storage Backends

If you want to use a different backend (Wandb, Tensorboard, or a custom backend with a remote window), just change the `vis_backends` in the config, as follows:
@@ -86,3 +89,114 @@ E.g.
## Customize Visualization

The customization of the visualization is similar to other components. If you want to customize `Visualizer`, `VisBackend` or `VisualizationHook`, you can refer to the [Visualization Doc](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/visualization.md) in MMEngine.

## Visualize Datasets

`tools/misc/browse_dataset.py` helps the user to browse an mmselfsup dataset (transformed images) visually, or save the images to a designated directory.

```shell
python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
```

An example:

```shell
python tools/misc/browse_dataset.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py
```

An example of visualization:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199387454-219e6f6c-fbb7-43bb-b319-61d3e6266abc.png" width="600" />
</div>

- The left two pictures are images from the contrastive learning data pipeline.
- The right one is a masked image.

## Visualize t-SNE

We provide an off-the-shelf tool to visualize the quality of image representations by t-SNE.

```shell
python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT_PATH} --work-dir ${WORK_DIR} [optional arguments]
```

Arguments:

- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of the model's checkpoint.
- `WORK_DIR`: the directory to save the results of visualization.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)

An example:

```shell
python tools/analysis_tools/visualize_tsne.py configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-100e_in1k.py --checkpoint epoch_100.pth --work-dir work_dirs/selfsup/simsiam_resnet50_8xb32-coslr-200e_in1k
```

An example of visualization:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/199388251-476a5ad2-f9c1-4dfb-afe2-73cf41b5793b.jpg" width="800" />
</div>

## Visualize Low-level Feature Reconstruction

We provide reconstruction visualization for the following algorithms:

- MAE
- SimMIM
- MaskFeat

Users can run the command below to visualize the reconstruction.

```shell
python tools/analysis_tools/visualize_reconstruction.py ${CONFIG_FILE} \
    --checkpoint ${CKPT_PATH} \
    --img-path ${IMAGE_PATH} \
    --out-file ${OUTPUT_PATH}
```

Arguments:

- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of the model's checkpoint.
- `IMAGE_PATH`: the input image path.
- `OUTPUT_PATH`: the output image path, including 4 sub-images.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_reconstruction.py](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/analysis_tools/visualize_reconstruction.py)

An example:

```shell
python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/mae/mae_vit-huge-p16_8xb512-amp-coslr-1600e_in1k.py \
    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth \
    --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
    --out-file test_mae.jpg \
    --norm-pix


# SimMIM generates the mask in the data pipeline, so we use '--use-vis-pipeline' to apply the 'vis_pipeline' defined in the config instead of the pipeline defined in the script.
python tools/analysis_tools/visualize_reconstruction.py configs/selfsup/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192.py \
    --checkpoint https://download.openmmlab.com/mmselfsup/1.x/simmim/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192/simmim_swin-large_16xb128-amp-coslr-800e_in1k-192_20220916-4ad216d3.pth \
    --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
    --out-file test_simmim.jpg \
    --use-vis-pipeline
```

Results of MAE:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200465826-83f316ed-5a46-46a9-b665-784b5332d348.jpg" width="800" />
</div>

Results of SimMIM:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200466133-b77bc9af-224b-4810-863c-eed81ddd1afa.jpg" width="800" />
</div>

Results of MaskFeat:

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/200465876-7e7dcb6f-5e8d-4d80-b300-9e1847cb975f.jpg" width="800" />
</div>
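If you would rather drive the new tooling from Python than from the shell, the flow the script wraps can be sketched as follows (a minimal sketch assuming an MAE config with `norm_pix=True`; the config, checkpoint, and image paths are illustrative):

```python
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

register_all_modules()
model = init_model('mae_config.py', 'mae_checkpoint.pth', device='cpu')
torch.manual_seed(0)  # make the random mask reproducible

# preprocess one image with the model's test pipeline
pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
data = default_collate([pipeline(dict(img_path='fox.jpg'))])
img, _ = model.data_preprocessor(data, False)

# inference_model returns raw features (mode='tensor');
# reconstruct() turns them into pred/mask tensors for plotting
embedding = model.head.patchify(img[0])  # per-patch stats for norm_pix
mean = embedding.mean(dim=-1, keepdim=True)
std = (embedding.var(dim=-1, keepdim=True) + 1e-6)**0.5
features = inference_model(model, 'fox.jpg')
results = model.reconstruct(features, mean=mean, std=std)
print(results.pred.value.shape, results.mask.value.shape)  # (N, H, W, 3)
```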
@@ -81,5 +81,6 @@ def inference_model(model: nn.Module,

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)
    return results
        inputs, data_samples = model.data_preprocessor(data, False)
        features = model(inputs, data_samples, mode='tensor')
    return features
@@ -119,7 +119,7 @@ class BaseModel(_BaseModel):
            or ``dict of tensor for custom use.
        """
        if mode == 'tensor':
            feats = self.extract_feat(inputs)
            feats = self.extract_feat(inputs, data_samples=data_samples)
            return feats
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
@@ -17,7 +17,9 @@ class MAE(BaseModel):
    <https://arxiv.org/abs/2111.06377>`_.
    """

    def extract_feat(self, inputs: List[torch.Tensor],
    def extract_feat(self,
                     inputs: List[torch.Tensor],
                     data_samples: Optional[List[SelfSupDataSample]] = None,
                     **kwarg) -> Tuple[torch.Tensor]:
        """The forward function to extract features from neck.

@@ -27,33 +29,33 @@ class MAE(BaseModel):
        Returns:
            Tuple[torch.Tensor]: Neck outputs.
        """
        latent, _, ids_restore = self.backbone(inputs[0])
        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)
        self.mask = mask
        return pred

    def predict(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[SelfSupDataSample]] = None,
                **kwargs) -> SelfSupDataSample:
        """The forward function in testing. It is mainly for image
        reconstruction.
    def reconstruct(self,
                    features: torch.Tensor,
                    data_samples: Optional[List[SelfSupDataSample]] = None,
                    **kwargs) -> SelfSupDataSample:
        """The function is for image reconstruction.

        Args:
            inputs (List[torch.Tensor]): The input images.
            features (torch.Tensor): The predicted features from the neck.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            SelfSupDataSample: The prediction from model.
        """
        mean = kwargs['mean']
        std = kwargs['std']
        features = features * std + mean

        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)

        pred = self.head.unpatchify(pred)
        pred = self.head.unpatchify(features)
        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()

        mask = mask.detach()
        mask = self.mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
                                         3)  # (N, H*W, p*p*3)
        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
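Because MAE pre-trains with `norm_pix_loss`, `reconstruct` expects the per-patch `mean` and `std` of the target image via `**kwargs` and de-normalizes the predictions before unpatchifying. A small standalone sketch of that statistic computation (tensor sizes follow the ViT-base 224x224 setting used in the unit test; the tensors here are stand-ins):

```python
import torch

patches = torch.randn(2, 196, 768)  # stand-in for head.patchify(img): (N, L, p*p*3)
mean = patches.mean(dim=-1, keepdim=True)              # (N, L, 1)
std = (patches.var(dim=-1, keepdim=True) + 1e-6)**0.5  # (N, L, 1)

pred = torch.randn(2, 196, 768)  # stand-in for normalized neck predictions
pred = pred * std + mean         # undo norm_pix before head.unpatchify(pred)
```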
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from mmengine.structures import BaseDataElement

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample

@@ -16,8 +17,10 @@ class MaskFeat(BaseModel):
    Pre-Training <https://arxiv.org/abs/2112.09133>`_.
    """

    def extract_feat(self, inputs: List[torch.Tensor],
    def extract_feat(self,
                     inputs: List[torch.Tensor],
                     data_samples: List[SelfSupDataSample],
                     compute_hog: bool = True,
                     **kwarg) -> Tuple[torch.Tensor]:
        """The forward function to extract features from neck.

@@ -25,15 +28,30 @@ class MaskFeat(BaseModel):
            inputs (List[torch.Tensor]): The input images and mask.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.
            compute_hog (bool): Whether to compute HOG during extraction. If
                True, the batch size of inputs needs to be 1. Defaults to True.

        Returns:
            Tuple[torch.Tensor]: Neck outputs.
        """
        img = inputs[0]
        mask = torch.stack(
        self.mask = torch.stack(
            [data_sample.mask.value for data_sample in data_samples])
        latent = self.backbone(img, mask)
        return latent
        latent = self.backbone(img, self.mask)
        B, L, C = latent.shape
        pred = self.neck([latent.view(B * L, C)])
        pred = pred[0].view(B, L, -1)

        # compute hog
        if compute_hog:
            assert img.size(0) == 1, 'Currently only support batch size 1.'
            _ = self.target_generator(img)
            hog_image = torch.from_numpy(
                self.target_generator.generate_hog_image(
                    self.target_generator.out)).unsqueeze(0).unsqueeze(0)
            self.target = hog_image.expand(-1, 3, -1, -1)

        return pred[:, 1:, :]  # remove cls token

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],

@@ -62,3 +80,60 @@ class MaskFeat(BaseModel):
        loss = self.head(pred, hog, mask)
        losses = dict(loss=loss)
        return losses

    def reconstruct(self,
                    features: List[torch.Tensor],
                    data_samples: Optional[List[SelfSupDataSample]] = None,
                    **kwargs) -> SelfSupDataSample:
        """The function is for image reconstruction.

        Args:
            features (List[torch.Tensor]): The predicted features from the
                neck.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            SelfSupDataSample: The prediction from model.
        """

        # recover to HOG description from feature embeddings
        unfold_size = self.target_generator.unfold_size
        tmp4 = features.unflatten(2,
                                  (features.shape[2] // unfold_size**2,
                                   unfold_size, unfold_size))  # 1,196,27,2,2
        tmp3 = tmp4.unflatten(1, self.backbone.patch_resolution)

        b, p1, p2, c_nbins, _, _ = tmp3.shape  # 1,14,14,27,2,2
        tmp2 = tmp3.permute(0, 1, 2, 5, 3, 4).reshape(
            (b, p1, p2 * unfold_size, c_nbins, unfold_size))
        tmp1 = tmp2.permute(0, 1, 4, 2, 3).reshape(
            (b, p1 * unfold_size, p2 * unfold_size, c_nbins))
        tmp0 = tmp1.permute(0, 3, 1, 2)  # 1,27,28,28
        hog_out = tmp0.unflatten(1,
                                 (int(c_nbins // self.target_generator.nbins),
                                  self.target_generator.nbins))  # 1,3,9,28,28

        # generate prediction of HOG
        hog_image = torch.from_numpy(
            self.target_generator.generate_hog_image(hog_out))
        hog_image = hog_image.unsqueeze(0).unsqueeze(0)
        pred = torch.einsum('nchw->nhwc', hog_image).expand(-1, -1, -1,
                                                            3).detach().cpu()

        # transform patch mask to pixel mask
        mask = self.mask
        patch_dim_1 = int(self.backbone.patch_embed.init_input_size[0] //
                          self.backbone.patch_resolution[0])
        patch_dim_2 = int(self.backbone.patch_embed.init_input_size[1] //
                          self.backbone.patch_resolution[1])
        mask = mask.repeat_interleave(
            patch_dim_1, dim=1).repeat_interleave(
                patch_dim_2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3)
        # 1 is removing, 0 is keeping
        mask = mask.detach().cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(**dict(value=mask))
        results.pred = BaseDataElement(**dict(value=pred))

        return results
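The `tmp*` reshaping chain above inverts `HOGGenerator._reshape`: it folds each patch's flat feature vector back into the spatial HOG layout before rendering. A standalone check of the shape arithmetic for the default setting (14x14 patches, 3 channels x 9 bins = 27 HOG channels, `unfold_size` of 2; the dummy tensor is a stand-in for the `extract_feat` output):

```python
import torch

feat = torch.randn(1, 196, 108)            # (B, L, C) from extract_feat
t = feat.unflatten(2, (27, 2, 2))          # (1, 196, 27, 2, 2)
t = t.unflatten(1, (14, 14))               # (1, 14, 14, 27, 2, 2)
b, p1, p2, c, _, _ = t.shape
t = t.permute(0, 1, 2, 5, 3, 4).reshape(b, p1, p2 * 2, c, 2)
t = t.permute(0, 1, 4, 2, 3).reshape(b, p1 * 2, p2 * 2, c)
hog = t.permute(0, 3, 1, 2).unflatten(1, (3, 9))
print(hog.shape)                           # torch.Size([1, 3, 9, 28, 28])
```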
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from typing import Dict, List, Optional

import torch
from mmengine.structures import BaseDataElement

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample

@@ -33,6 +34,7 @@ class SimMIM(BaseModel):
            [data_sample.mask.value for data_sample in data_samples])
        img_latent = self.backbone(inputs[0], mask)
        feat = self.neck(img_latent[0])
        self.mask = mask
        return feat

    def loss(self, inputs: List[torch.Tensor],

@@ -58,3 +60,37 @@ class SimMIM(BaseModel):
        losses = dict(loss=loss)

        return losses

    def reconstruct(self,
                    features: torch.Tensor,
                    data_samples: Optional[List[SelfSupDataSample]] = None,
                    **kwargs) -> SelfSupDataSample:
        """The function is for image reconstruction.

        Args:
            features (torch.Tensor): The predicted features from the neck.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            SelfSupDataSample: The prediction from model.
        """
        pred = torch.einsum('nchw->nhwc', features).detach().cpu()

        # transform patch mask to pixel mask
        mask = self.mask.detach()
        p1 = int(self.backbone.patch_embed.init_input_size[0] //
                 self.backbone.patch_resolution[0])
        p2 = int(self.backbone.patch_embed.init_input_size[1] //
                 self.backbone.patch_resolution[1])
        mask = mask.repeat_interleave(
            p1, dim=1).repeat_interleave(
                p2, dim=2).unsqueeze(-1).repeat(1, 1, 1, 3)  # (N, H, W, 3)
        # 1 is removing, 0 is keeping
        mask = mask.detach().cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(**dict(value=mask))
        results.pred = BaseDataElement(**dict(value=pred))

        return results
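The patch-to-pixel mask expansion in `reconstruct` is a nearest-neighbour upsample: each patch entry is repeated over its pixel footprint and broadcast to three channels. A self-contained example with the SimMIM sizes (192x192 image, 32-pixel mask patches):

```python
import torch

mask = torch.randint(0, 2, (2, 6, 6))          # patch-level mask (N, H/p, W/p)
p = 32                                         # pixels per mask-patch side
pixel_mask = mask.repeat_interleave(p, dim=1)  # (2, 192, 6)
pixel_mask = pixel_mask.repeat_interleave(p, dim=2)       # (2, 192, 192)
pixel_mask = pixel_mask.unsqueeze(-1).repeat(1, 1, 1, 3)  # (2, 192, 192, 3)
```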
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModule

@@ -13,10 +15,10 @@ class HOGGenerator(BaseModule):
    """Generate HOG feature for images.

    This module is used in MaskFeat to generate HOG feature. The code is
    modified from this `file
    modified from file `slowfast/models/operators.py
    <https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
    Here is the link `HOG wikipedia
    <https://en.m.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
    Here is the link of `HOG wikipedia
    <https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.

    Args:
        nbins (int): Number of bins. Defaults to 9.

@@ -61,12 +63,12 @@ class HOGGenerator(BaseModule):
    def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
        """Reshape HOG features for output."""
        hog_feat = hog_feat.flatten(1, 2)
        unfold_size = hog_feat.shape[-1] // 14
        hog_feat = (
            hog_feat.permute(0, 2, 3,
                             1).unfold(1, unfold_size, unfold_size).unfold(
                                 2, unfold_size,
                                 unfold_size).flatten(1, 2).flatten(2))
        self.unfold_size = hog_feat.shape[-1] // 14
        hog_feat = hog_feat.permute(0, 2, 3, 1)
        hog_feat = hog_feat.unfold(1, self.unfold_size,
                                   self.unfold_size).unfold(
                                       2, self.unfold_size, self.unfold_size)
        hog_feat = hog_feat.flatten(1, 2).flatten(2)
        return hog_feat

    @torch.no_grad()

@@ -80,6 +82,7 @@ class HOGGenerator(BaseModule):
            torch.Tensor: HOG features.
        """
        # input is RGB image with shape [B 3 H W]
        self.h, self.w = x.size(-2), x.size(-1)
        x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
        gx_rgb = F.conv2d(
            x, self.weight_x, bias=None, stride=1, padding=0, groups=3)

@@ -112,6 +115,38 @@ class HOGGenerator(BaseModule):
        out = out.unfold(4, self.pool, self.pool)
        out = out.sum(dim=[-1, -2])

        out = F.normalize(out, p=2, dim=2)
        self.out = F.normalize(out, p=2, dim=2)

        return self._reshape(out)
        return self._reshape(self.out)

    def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray:
        """Generate HOG image according to HOG features."""
        assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \
            'Check the input batch size and the channel number; only ' \
            '"batch_size = 1" is supported.'
        hog_image = np.zeros([self.h, self.w])
        cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu())
        cell_width = self.pool / 2
        max_mag = np.array(cell_gradient).max()
        angle_gap = 360 / self.nbins

        for x in range(cell_gradient.shape[1]):
            for y in range(cell_gradient.shape[2]):
                cell_grad = cell_gradient[:, x, y]
                cell_grad /= max_mag
                angle = 0
                for magnitude in cell_grad:
                    angle_radian = math.radians(angle)
                    x1 = int(x * self.pool +
                             magnitude * cell_width * math.cos(angle_radian))
                    y1 = int(y * self.pool +
                             magnitude * cell_width * math.sin(angle_radian))
                    x2 = int(x * self.pool -
                             magnitude * cell_width * math.cos(angle_radian))
                    y2 = int(y * self.pool -
                             magnitude * cell_width * math.sin(angle_radian))
                    magnitude = 0 if magnitude < 0 else magnitude
                    cv2.line(hog_image, (y1, x1), (y2, x2),
                             int(255 * math.sqrt(magnitude)))
                    angle += angle_gap
        return hog_image
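Putting the cached `self.out` and the new `generate_hog_image` together, mirroring the unit test below (default 9-bin, pool-8 configuration):

```python
import torch
from mmselfsup.models.target_generators import HOGGenerator

hog_generator = HOGGenerator(nbins=9, pool=8, gaussian_window=16)
feat = hog_generator(torch.randn(2, 3, 224, 224))  # (2, 196, 108)

# generate_hog_image needs a single sample: (1, 3, nbins, H/pool, W/pool)
hog_out = hog_generator.out[0].unsqueeze(0)
hog_img = hog_generator.generate_hog_image(hog_out)  # np.ndarray, (224, 224)
```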
@@ -27,10 +27,10 @@ class ExampleModel(BaseModel):
        super(ExampleModel, self).__init__(backbone=backbone)
        self.layer = nn.Linear(1, 1)

    def predict(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[SelfSupDataSample]] = None,
                **kwargs) -> SelfSupDataSample:
    def extract_feat(self,
                     inputs: List[torch.Tensor],
                     data_samples: Optional[List[SelfSupDataSample]] = None,
                     **kwargs) -> SelfSupDataSample:
        out = self.layer(inputs[0])
        return out


@@ -48,9 +48,14 @@ def test_mae():
    fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
    assert isinstance(fake_outputs['loss'].item(), float)

    # test extraction
    fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
    assert list(fake_feats.shape) == [2, 196, 768]

    results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
    # test reconstruct
    mean = fake_feats.mean(dim=-1, keepdim=True)
    std = (fake_feats.var(dim=-1, keepdim=True) + 1.e-6)**.5
    results = alg.reconstruct(
        fake_feats, fake_data_samples, mean=mean, std=std)
    assert list(results.mask.value.shape) == [2, 224, 224, 3]
    assert list(results.pred.value.shape) == [2, 224, 224, 3]
@@ -5,6 +5,7 @@ import platform
import pytest
import torch
from mmengine.structures import InstanceData
from mmengine.utils import digit_version

from mmselfsup.models.algorithms.maskfeat import MaskFeat
from mmselfsup.structures import SelfSupDataSample

@@ -22,6 +23,9 @@ target_generator = dict(
    type='HOGGenerator', nbins=9, pool=8, gaussian_window=16)


@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.7.0'),
    reason='torch version')
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat():
    data_preprocessor = {

@@ -42,13 +46,19 @@ def test_maskfeat():
    fake_mask = InstanceData(value=torch.rand((14, 14)).bool())
    fake_data_sample.mask = fake_mask
    fake_data = {
        'inputs': [torch.randn((2, 3, 224, 224))],
        'data_sample': [fake_data_sample for _ in range(2)]
        'inputs': [torch.randn((1, 3, 224, 224))],
        'data_sample': [fake_data_sample for _ in range(1)]
    }

    fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
    fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
    assert isinstance(fake_outputs['loss'].item(), float)

    # test extraction
    fake_feats = alg.extract_feat(fake_batch_inputs, fake_data_samples)
    assert list(fake_feats.shape) == [2, 197, 768]
    assert list(fake_feats.shape) == [1, 196, 108]

    # test reconstruction
    results = alg.reconstruct(fake_feats, fake_data_samples)
    assert list(results.mask.value.shape) == [1, 224, 224, 3]
    assert list(results.pred.value.shape) == [1, 224, 224, 3]
@@ -50,5 +50,10 @@ def test_simmim():

    # test extract_feat
    fake_inputs, fake_data_samples = model.data_preprocessor(fake_data)
    fake_feat = model.extract_feat(fake_inputs, fake_data_samples)
    assert list(fake_feat.shape) == [2, 3, 192, 192]
    fake_feats = model.extract_feat(fake_inputs, fake_data_samples)
    assert list(fake_feats.shape) == [2, 3, 192, 192]

    # test reconstruct
    results = model.reconstruct(fake_feats, fake_data_samples)
    assert list(results.mask.value.shape) == [2, 192, 192, 3]
    assert list(results.pred.value.shape) == [2, 192, 192, 3]
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmselfsup.models.target_generators import HOGGenerator

@@ -10,3 +11,10 @@ def test_hog_generator():
    fake_input = torch.randn((2, 3, 224, 224))
    fake_output = hog_generator(fake_input)
    assert list(fake_output.shape) == [2, 196, 108]

    fake_hog_out = hog_generator.out[0].unsqueeze(0)
    fake_hog_img = hog_generator.generate_hog_image(fake_hog_out)
    assert fake_hog_img.shape == (224, 224)

    with pytest.raises(AssertionError):
        fake_hog_img = hog_generator.generate_hog_image(hog_generator.out[0])
@@ -1,108 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image: torch.Tensor, title: str = '') -> None:
    # image is [H, W, 3]
    assert image.shape[2] == 3
    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
                       255).int()
    plt.imshow(image)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
                im_paste: torch.Tensor, out_file: str) -> None:
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(x, 'original')

    plt.subplot(1, 4, 2)
    show_image(im_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(y, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(im_paste, 'reconstruction + visible')

    plt.savefig(out_file)


def post_process(
    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x = torch.einsum('nchw->nhwc', x.cpu())
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask
    return x[0], im_masked[0], y[0], im_paste[0]


def main():
    parser = ArgumentParser()
    parser.add_argument('img_path', help='Image file path')
    parser.add_argument('config', help='MAE Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('out_file', help='The output image file path')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)
    print('MAE with pixel reconstruction:')

    model.cfg.test_dataloader = dict(
        dataset=dict(pipeline=[
            dict(
                type='LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='Resize', scale=(224, 224)),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))

    results = inference_model(model, args.img_path)

    cfg = model.cfg
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
    data = dict(img_path=args.img_path)
    data = test_pipeline(data)
    data = default_collate([data])
    img, _ = data_preprocessor(data, False)

    x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
                                             results.mask.value)
    save_images(x, im_masked, y, im_paste, args.out_file)


if __name__ == '__main__':
    main()
@@ -0,0 +1,178 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
import random
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(img: torch.Tensor, title: str = '') -> None:
    # image is [H, W, 3]
    assert img.shape[2] == 3

    plt.imshow(img)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def save_images(original_img: torch.Tensor, img_masked: torch.Tensor,
                pred_img: torch.Tensor, img_paste: torch.Tensor,
                out_file: str) -> None:
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(original_img, 'original')

    plt.subplot(1, 4, 2)
    show_image(img_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(pred_img, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(img_paste, 'reconstruction + visible')

    plt.savefig(out_file)
    print(f'Images are saved to {out_file}')


def recover_norm(img: torch.Tensor,
                 mean: np.ndarray = imagenet_mean,
                 std: np.ndarray = imagenet_std):
    if mean is not None and std is not None:
        img = torch.clip((img * std + mean) * 255, 0, 255).int()
    return img


def post_process(
    original_img: torch.Tensor,
    pred_img: torch.Tensor,
    mask: torch.Tensor,
    mean: np.ndarray = imagenet_mean,
    std: np.ndarray = imagenet_std
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # channel conversion
    original_img = torch.einsum('nchw->nhwc', original_img.cpu())
    # masked image
    img_masked = original_img * (1 - mask)
    # reconstructed image pasted with visible patches
    img_paste = original_img * (1 - mask) + pred_img * mask

    # multiply std and add mean to each image
    original_img = recover_norm(original_img[0])
    img_masked = recover_norm(img_masked[0])

    pred_img = recover_norm(pred_img[0])
    img_paste = recover_norm(img_paste[0])

    return original_img, img_masked, pred_img, img_paste


def main():
    parser = ArgumentParser()
    parser.add_argument('config', help='Model config file')
    parser.add_argument('--checkpoint', help='Checkpoint file')
    parser.add_argument('--img-path', help='Image file path')
    parser.add_argument('--out-file', help='The output image file path')
    parser.add_argument(
        '--use-vis-pipeline',
        action='store_true',
        help='Use the vis_pipeline defined in the config. Some algorithms, '
        'such as SimMIM and MaskFeat, generate the mask in the data '
        'pipeline, so the visualization process applies the vis_pipeline in '
        'the config to obtain the mask.')
    parser.add_argument(
        '--norm-pix',
        action='store_true',
        help='MAE uses `norm_pix_loss` for optimization in pre-training, so '
        'the visualization process also needs to compute the mean and std '
        'of each patch embedding while reconstructing the original images.')
    parser.add_argument(
        '--target-generator',
        action='store_true',
        help='Some algorithms use a target_generator for optimization in '
        'pre-training, such as MaskFeat; the visualization process can turn '
        'this on to visualize the target instead of the RGB image.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='The random seed for visualization')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print('Reconstruction visualization.')

    if args.use_vis_pipeline:
        model.cfg.test_dataloader = dict(
            dataset=dict(pipeline=model.cfg.vis_pipeline))
    else:
        model.cfg.test_dataloader = dict(
            dataset=dict(pipeline=[
                dict(
                    type='LoadImageFromFile',
                    file_client_args=dict(backend='disk')),
                dict(type='Resize', scale=(224, 224), backend='pillow'),
                dict(type='PackSelfSupInputs', meta_keys=['img_path'])
            ]))

    # get original image
    vis_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
    data = dict(img_path=args.img_path)
    data = vis_pipeline(data)
    data = default_collate([data])
    img, _ = model.data_preprocessor(data, False)

    if args.norm_pix:
        # for MAE reconstruction
        img_embedding = model.head.patchify(img[0])
        # normalize the target image
        mean = img_embedding.mean(dim=-1, keepdim=True)
        std = (img_embedding.var(dim=-1, keepdim=True) + 1.e-6)**.5
    else:
        mean = imagenet_mean
        std = imagenet_std

    # get reconstruction image
    features = inference_model(model, args.img_path)
    results = model.reconstruct(features, mean=mean, std=std)

    original_target = model.target if args.target_generator else img[0]

    original_img, img_masked, pred_img, img_paste = post_process(
        original_target,
        results.pred.value,
        results.mask.value,
        mean=mean,
        std=std)

    save_images(original_img, img_masked, pred_img, img_paste, args.out_file)


if __name__ == '__main__':
    main()
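For completeness, MaskFeat can be driven through the same script; it needs both `--use-vis-pipeline` (its mask comes from the data pipeline) and `--target-generator` (its target is a HOG map rather than RGB pixels). A hypothetical invocation, with the config and checkpoint paths left as placeholders:

```shell
python tools/analysis_tools/visualize_reconstruction.py ${MASKFEAT_CONFIG} \
    --checkpoint ${MASKFEAT_CKPT} \
    --img-path data/imagenet/val/ILSVRC2012_val_00000003.JPEG \
    --out-file test_maskfeat.jpg \
    --use-vis-pipeline \
    --target-generator
```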