[Refactor]: refactor MAE visualization (#471)
* [Refactor]: refactor MAE visualization
* [Fix]: fix lint
* [Refactor]: refactor MAE visualization
* [Feature]: add mae_visualization.py
* [UT]: add unit test
* [Refactor]: move mae_visualization.py to tools/analysis_tools
* [Docs]: Add the purpose of the function unpatchify()
* [Fix]: fix lint
parent d142699265
commit 41747f73c7
@@ -0,0 +1,403 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) OpenMMLab. All rights reserved.\n",
    "\n",
    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
    "\n",
    "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
    "\n",
    "## Masked Autoencoders: Visualization Demo\n",
    "\n",
    "This is a visualization demo using our pre-trained MAE models. No GPU is needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare\n",
    "Check the environment. Install packages if running in Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# check whether we are running in Colab\n",
    "if 'google.colab' in sys.modules:\n",
    "    print('Running in Colab.')\n",
    "    !pip3 install openmim\n",
    "    !pip install -U openmim\n",
    "    !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
    "\n",
    "    !git clone https://github.com/open-mmlab/mmselfsup.git\n",
    "    %cd mmselfsup/\n",
    "    !git checkout dev-1.x\n",
    "    !pip install -e .\n",
    "\n",
    "    sys.path.append('./mmselfsup')\n",
    "    %cd demo\n",
    "else:\n",
    "    sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from mmengine.dataset import Compose, default_collate\n",
    "\n",
    "from mmselfsup.apis import inference_model\n",
    "from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
    "from mmselfsup.registry import MODELS\n",
    "from mmselfsup.utils import register_all_modules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the utils\n",
    "\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "def show_image(image, title=''):\n",
    "    # image is [H, W, 3]\n",
    "    assert image.shape[2] == 3\n",
    "    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
    "    plt.imshow(image)\n",
    "    plt.title(title, fontsize=16)\n",
    "    plt.axis('off')\n",
    "    return\n",
    "\n",
    "\n",
    "def show_images(x, im_masked, y, im_paste):\n",
    "    # make the plt figure larger\n",
    "    plt.rcParams['figure.figsize'] = [24, 6]\n",
    "\n",
    "    plt.subplot(1, 4, 1)\n",
    "    show_image(x, \"original\")\n",
    "\n",
    "    plt.subplot(1, 4, 2)\n",
    "    show_image(im_masked, \"masked\")\n",
    "\n",
    "    plt.subplot(1, 4, 3)\n",
    "    show_image(y, \"reconstruction\")\n",
    "\n",
    "    plt.subplot(1, 4, 4)\n",
    "    show_image(im_paste, \"reconstruction + visible\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def post_process(x, y, mask):\n",
    "    x = torch.einsum('nchw->nhwc', x.cpu())\n",
    "    # masked image\n",
    "    im_masked = x * (1 - mask)\n",
    "    # MAE reconstruction pasted with visible patches\n",
    "    im_paste = x * (1 - mask) + y * mask\n",
    "    return x[0], im_masked[0], y[0], im_paste[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare config file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
    "model = dict(\n",
    "    type='MAE',\n",
    "    data_preprocessor=dict(\n",
    "        mean=[123.675, 116.28, 103.53],\n",
    "        std=[58.395, 57.12, 57.375],\n",
    "        bgr_to_rgb=True),\n",
    "    backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
    "    neck=dict(\n",
    "        type='MAEPretrainDecoder',\n",
    "        patch_size=16,\n",
    "        in_chans=3,\n",
    "        embed_dim=1024,\n",
    "        decoder_embed_dim=512,\n",
    "        decoder_depth=8,\n",
    "        decoder_num_heads=16,\n",
    "        mlp_ratio=4.,\n",
    "    ),\n",
    "    head=dict(\n",
    "        type='MAEPretrainHead',\n",
    "        norm_pix=True,\n",
    "        patch_size=16,\n",
    "        loss=dict(type='MAEReconstructionLoss')),\n",
    "    init_cfg=[\n",
    "        dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
    "        dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
    "    ])\n",
    "\n",
    "file_client_args = dict(backend='disk')\n",
    "\n",
    "# dataset summary\n",
    "test_dataloader = dict(\n",
    "    dataset=dict(pipeline=[\n",
    "        dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
    "        dict(type='Resize', scale=(224, 224)),\n",
    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
    "    ]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load a pre-trained MAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-09-03 00:34:55--  https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1318299501 (1.2G) [application/octet-stream]\n",
      "Saving to: 'mae_visualize_vit_large.pth'\n",
      "\n",
      "mae_visualize_vit_l 100%[===================>]   1.23G  3.22MB/s    in 6m 4s   \n",
      "\n",
      "2022-09-03 00:40:59 (3.46 MB/s) - 'mae_visualize_vit_large.pth' saved [1318299501/1318299501]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
    "\n",
    "# download the checkpoint if it does not exist\n",
    "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
    "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "local loads checkpoint from path: mae_visualize_vit_large.pth\n",
      "Model loaded.\n"
     ]
    }
   ],
   "source": [
    "from mmselfsup.apis import init_model\n",
    "ckpt_path = \"mae_visualize_vit_large.pth\"\n",
    "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
    "print('Model loaded.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load an image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f5029d19950>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# make the random mask reproducible (comment out to make it change)\n",
    "register_all_modules()\n",
    "torch.manual_seed(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-09-03 00:41:01--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
      "Resolving download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
      "Connecting to download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 60133 (59K) [image/jpeg]\n",
      "Saving to: 'fox.jpg'\n",
      "\n",
      "fox.jpg             100%[===================>]  58.72K  --.-KB/s    in 0.06s   \n",
      "\n",
      "2022-09-03 00:41:01 (962 KB/s) - 'fox.jpg' saved [60133/60133]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_path = 'fox.jpg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = model.cfg\n",
    "test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
    "data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = dict(img_path=img_path)\n",
    "data = test_pipeline(data)\n",
    "data = default_collate([data])\n",
    "img, _ = data_preprocessor(data, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = [5, 5]\n",
    "show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run MAE on the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = inference_model(model, img_path)\n",
    "x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('MAE with pixel reconstruction:')\n",
    "show_images(x, im_masked, y, im_paste)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
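A note on the tensor conventions in post_process above (a standalone sketch, not part of the commit; shapes follow the demo config with 224x224 inputs): x arrives as NCHW from the data preprocessor, while pred and mask come back NHWC from MAE.predict, with mask value 1 marking a masked pixel.

import torch

N, H, W = 1, 224, 224
x = torch.rand(N, 3, H, W)            # original image, NCHW from the preprocessor
y = torch.rand(N, H, W, 3)            # reconstruction, NHWC from MAE.predict
mask = torch.zeros(N, H, W, 3)        # per-pixel mask: 1 = masked, 0 = visible
x = torch.einsum('nchw->nhwc', x)     # channels last, to match y and mask
im_masked = x * (1 - mask)            # keep only the visible pixels
im_paste = x * (1 - mask) + y * mask  # visible pixels plus reconstruction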
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model

__all__ = ['init_model', 'inference_model']
@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate
from mmengine.runner import load_checkpoint
from torch import nn

from mmselfsup.models import build_algorithm
from mmselfsup.structures import SelfSupDataSample


def init_model(config: Union[str, Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               options: Optional[dict] = None) -> nn.Module:
    """Initialize a model from the config file.

    Args:
        config (str or :obj:`mmengine.Config`): Config file path or the
            config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.
        device (str): The device where the model will be put on.
            Defaults to 'cuda:0'.
        options (dict, optional): Options to override some settings in the
            used config.

    Returns:
        nn.Module: The initialized model.
    """
    if isinstance(config, str):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if options is not None:
        config.merge_from_dict(options)
    config.model.pretrained = None
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))
    model = build_algorithm(config.model)
    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak,
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(model: nn.Module,
                    img: Union[str, np.ndarray]) -> SelfSupDataSample:
    """Inference an image with the mmselfsup model.

    Args:
        model (nn.Module): The loaded model.
        img (Union[str, ndarray]): Either an image path or a loaded image.

    Returns:
        SelfSupDataSample: Output of the model inference.
    """
    cfg = model.cfg
    # build the data pipeline
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    if isinstance(img, str):
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # forward the model
    with torch.no_grad():
        results = model.test_step(data)
    return results
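For reference, a minimal usage sketch of the two APIs above, using the file names from the demo notebook (paths are illustrative, adjust to your checkout):

from mmselfsup.apis import inference_model, init_model
from mmselfsup.utils import register_all_modules

register_all_modules()  # register mmselfsup models/transforms first
model = init_model('configs/selfsup/mae/mae_visualization.py',
                   'mae_visualize_vit_large.pth', device='cpu')
results = inference_model(model, 'fox.jpg')  # -> SelfSupDataSample
print(results.pred.value.shape)  # (1, 224, 224, 3) with the demo config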
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
+from mmengine.structures import BaseDataElement
 
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
@@ -30,6 +31,40 @@ class MAE(BaseModel):
        pred = self.neck(latent, ids_restore)
        return pred

    def predict(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[SelfSupDataSample]] = None,
                **kwargs) -> SelfSupDataSample:
        """The forward function in testing. It is mainly for image
        reconstruction.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            SelfSupDataSample: The prediction from model.
        """

        latent, mask, ids_restore = self.backbone(inputs[0])
        pred = self.neck(latent, ids_restore)

        pred = self.head.unpatchify(pred)
        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()

        mask = mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
                                         3)  # (N, H*W, p*p*3)
        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

        results = SelfSupDataSample()
        results.mask = BaseDataElement(**dict(value=mask))
        results.pred = BaseDataElement(**dict(value=pred))

        return results

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
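The mask returned by the backbone is patch-level, shape (N, L) with 1 for a dropped patch; the unsqueeze/repeat in predict blows each entry up to a full patch worth of pixel values, so the same unpatchify used for predictions maps it back to image space. A shape sketch (numbers assumed from the demo config: patch size 16 on a 224x224 image, so L = 196):

import torch

N, L, p = 1, 196, 16
mask = torch.randint(0, 2, (N, L)).float()         # (N, L): 1 = masked patch
mask = mask.unsqueeze(-1).repeat(1, 1, p * p * 3)  # -> (N, L, p*p*3)
# head.unpatchify then folds (N, L, p*p*3) into (N, 3, 224, 224), giving a
# per-pixel mask that broadcasts against the reconstructed image.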
@@ -43,6 +43,23 @@ class MAEPretrainHead(BaseModule):
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Combine non-overlapped patches into images.

        Args:
            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3).
        Returns:
            imgs (torch.Tensor): The shape is (N, 3, H, W).
        """
        p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

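A quick walk-through of the reshapes in unpatchify, with the demo's numbers (p = 16, 224x224 images, so h = w = 14; a standalone sketch, not part of the commit):

import torch

p, h = 16, 14                          # patch size, patches per side
x = torch.rand(2, h * h, p * p * 3)    # (N, L, p*p*3) patch tokens
x = x.reshape(2, h, h, p, p, 3)        # split L into an (h, w) grid
x = torch.einsum('nhwpqc->nchpwq', x)  # interleave grid and patch dims
imgs = x.reshape(2, 3, h * p, h * p)   # -> (N, 3, 224, 224)
print(imgs.shape)                      # torch.Size([2, 3, 224, 224])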
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
from typing import List, Optional

import pytest
import torch
import torch.nn as nn
from mmengine.config import Config

from mmselfsup.apis import inference_model
from mmselfsup.models import BaseModel
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules

backbone = dict(
    type='ResNet',
    depth=18,
    in_channels=2,
    out_indices=[4],  # 0: conv-1, x: stage-x
    norm_cfg=dict(type='BN'))


class ExampleModel(BaseModel):

    def __init__(self, backbone=backbone):
        super(ExampleModel, self).__init__(backbone=backbone)
        self.layer = nn.Linear(1, 1)

    def predict(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[SelfSupDataSample]] = None,
                **kwargs) -> SelfSupDataSample:
        out = self.layer(inputs[0])
        return out


@pytest.mark.skipif(platform.system() == 'Windows', reason='')
def test_inference_model():
    register_all_modules()

    # Specify the data settings
    cfg = Config.fromfile(
        'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py'  # noqa: E501
    )
    # Build the algorithm
    model = ExampleModel()
    model.cfg = cfg
    model.cfg.test_dataloader = dict(
        dataset=dict(pipeline=[
            dict(
                type='LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='Resize', scale=(1, 1)),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))

    img_path = osp.join(osp.dirname(__file__), '..', 'data', 'color.jpg')

    # inference model
    out = inference_model(model, img_path)
    assert out.size() == torch.Size([1, 3, 1, 1])
@@ -50,3 +50,7 @@ def test_mae():

    fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
    assert list(fake_feats.shape) == [2, 196, 768]

    results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
    assert list(results.mask.value.shape) == [2, 224, 224, 3]
    assert list(results.pred.value.shape) == [2, 224, 224, 3]
@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from mmengine.dataset import Compose, default_collate

from mmselfsup.apis import inference_model, init_model
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image: torch.Tensor, title: str = '') -> None:
    # image is [H, W, 3]
    assert image.shape[2] == 3
    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
                       255).int()
    plt.imshow(image)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
                im_paste: torch.Tensor, out_file: str) -> None:
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(x, 'original')

    plt.subplot(1, 4, 2)
    show_image(im_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(y, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(im_paste, 'reconstruction + visible')

    plt.savefig(out_file)


def post_process(
    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x = torch.einsum('nchw->nhwc', x.cpu())
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask
    return x[0], im_masked[0], y[0], im_paste[0]


def main():
    parser = ArgumentParser()
    parser.add_argument('img_path', help='Image file path')
    parser.add_argument('config', help='MAE Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('out_file', help='The output image file path')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make the random mask reproducible (comment out to make it change)
    torch.manual_seed(2)
    print('MAE with pixel reconstruction:')

    model.cfg.test_dataloader = dict(
        dataset=dict(pipeline=[
            dict(
                type='LoadImageFromFile',
                file_client_args=dict(backend='disk')),
            dict(type='Resize', scale=(224, 224)),
            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
        ]))

    results = inference_model(model, args.img_path)

    cfg = model.cfg
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
    data = dict(img_path=args.img_path)
    data = test_pipeline(data)
    data = default_collate([data])
    img, _ = data_preprocessor(data, False)

    x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
                                             results.mask.value)
    save_images(x, im_masked, y, im_paste, args.out_file)


if __name__ == '__main__':
    main()
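Invocation follows the positional arguments defined in main(). For example, with the demo assets from the notebook above (the tools/analysis_tools path comes from the commit message; the output file name is a placeholder): python tools/analysis_tools/mae_visualization.py fox.jpg configs/selfsup/mae/mae_visualization.py mae_visualize_vit_large.pth mae_out.jpg --device cpu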