[Tools]: MAE Reconstructed Image Visualization (#376)

* [Tools]: MAE Reconstructed Image Visualization

* [Fix]: fix docstring and type hint

* [Fix]: fix docstring in MAE class

* [Fix]: fix docstring in MAE class

* [Fix]: fix type hint

* [Fix]: fix type hint and docstring

* [Refactor]: refactor super init
RenQin 2022-07-27 16:03:57 +08:00 committed by Yixiao Fang
parent 074ae09568
commit 3f530f085e
11 changed files with 638 additions and 24 deletions

View File

@@ -0,0 +1,268 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) OpenMMLab. All rights reserved.\n",
"\n",
"Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"\n",
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
"\n",
"## Masked Autoencoders: Visualization Demo\n",
"\n",
"This is a visualization demo using our pre-trained MAE models. No GPU is needed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare\n",
"Check environment. Install packages if in Colab."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import requests\n",
"\n",
"import torch\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"from mmselfsup.models import build_algorithm\n",
"\n",
"# check whether run in Colab\n",
"if 'google.colab' in sys.modules:\n",
" print('Running in Colab.')\n",
" !pip install openmim\n",
" !mim install mmcv-full\n",
" !git clone https://github.com/open-mmlab/mmselfsup.git\n",
" %cd mmselfsup/\n",
" !pip install -e .\n",
" sys.path.append('./mmselfsup')\n",
" %cd demo\n",
"else:\n",
" sys.path.append('..')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define utils"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# define the utils\n",
"\n",
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"def show_image(image, title=''):\n",
" # image is [H, W, 3]\n",
" assert image.shape[2] == 3\n",
" image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
" plt.imshow(image)\n",
" plt.title(title, fontsize=16)\n",
" plt.axis('off')\n",
" return\n",
"\n",
"\n",
"def show_images(x, im_masked, y, im_paste):\n",
" # make the plt figure larger\n",
" plt.rcParams['figure.figsize'] = [24, 6]\n",
"\n",
" plt.subplot(1, 4, 1)\n",
" show_image(x, \"original\")\n",
"\n",
" plt.subplot(1, 4, 2)\n",
" show_image(im_masked, \"masked\")\n",
"\n",
" plt.subplot(1, 4, 3)\n",
" show_image(y, \"reconstruction\")\n",
"\n",
" plt.subplot(1, 4, 4)\n",
" show_image(im_paste, \"reconstruction + visible\")\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"def post_process(x, y, mask):\n",
" x = torch.einsum('nchw->nhwc', x.cpu())\n",
" # masked image\n",
" im_masked = x * (1 - mask)\n",
" # MAE reconstruction pasted with visible patches\n",
" im_paste = x * (1 - mask) + y * mask\n",
" return x[0], im_masked[0], y[0], im_paste[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load an image"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# load an image\n",
"img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
"img_pil = Image.open(requests.get(img_url, stream=True).raw)\n",
"img = img_pil.resize((224, 224))\n",
"img = np.array(img) / 255.\n",
"\n",
"assert img.shape == (224, 224, 3)\n",
"\n",
"# normalize by ImageNet mean and std\n",
"img = img - imagenet_mean\n",
"img = img / imagenet_std\n",
"\n",
"plt.rcParams['figure.figsize'] = [5, 5]\n",
"show_image(torch.tensor(img))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load a pre-trained MAE model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
"model = dict(\n",
" type='MAE',\n",
" backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
" neck=dict(\n",
" type='MAEPretrainDecoder',\n",
" patch_size=16,\n",
" in_chans=3,\n",
" embed_dim=1024,\n",
" decoder_embed_dim=512,\n",
" decoder_depth=8,\n",
" decoder_num_heads=16,\n",
" mlp_ratio=4.,\n",
" ),\n",
" head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))\n",
"\n",
"img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"# dataset summary\n",
"data = dict(\n",
" test=dict(\n",
" pipeline = [\n",
" dict(type='Resize', size=(224, 224)),\n",
" dict(type='ToTensor'),\n",
" dict(type='Normalize', **img_norm_cfg),]\n",
" ))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
"\n",
"# download checkpoint if not exist\n",
"# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
"!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from mmselfsup.apis import init_model\n",
"ckpt_path = \"mae_visualize_vit_large.pth\"\n",
"model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
"print('Model loaded.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run MAE on the image"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# make random mask reproducible (comment out to make it change)\n",
"torch.manual_seed(2)\n",
"print('MAE with pixel reconstruction:')\n",
"\n",
"from mmselfsup.apis import inference_model\n",
"\n",
"img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
"img = Image.open(requests.get(img_url, stream=True).raw)\n",
"img, (mask, pred) = inference_model(model, img)\n",
"x, im_masked, y, im_paste = post_process(img, pred, mask)\n",
"show_images(x, im_masked, y, im_paste)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "3e4aeeccd14e965f43d0896afbaf8d71604e66b8605affbaa33ec76aa4083757"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -166,6 +166,27 @@ Arguments:
- `WORK_DIR`: the directory to save the results of visualization.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
### MAE Visualization
We provide a tool to visualize the mask and reconstructed image of the MAE model.
```shell
python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
```
Arguments:
- `IMG`: an image path used for visualization.
- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of the model's checkpoint.
- `DEVICE`: the device used for inference.
An example:
```shell
python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
```
### Reproducibility
If you want to make your performance exactly reproducible, please switch on `--deterministic` to train the final model to be published. Note that this flag will switch off `torch.backends.cudnn.benchmark` and slow down the training speed.
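
For reference, the flag corresponds to the seeding helper exported from `mmselfsup.apis`; below is a minimal sketch of the equivalent call, assuming the `set_random_seed(seed, deterministic=False)` signature re-exported in the `__init__.py` diff further down.

```python
from mmselfsup.apis import set_random_seed

# Roughly what --deterministic toggles: seed the Python/NumPy/PyTorch RNGs,
# then force deterministic cuDNN kernels and disable
# torch.backends.cudnn.benchmark.
set_random_seed(0, deterministic=True)
```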

View File

@@ -164,6 +164,27 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT
- `WORK_DIR`: the directory to save the visualization results.
- `[optional arguments]`: optional arguments; for details, refer to [visualize_tsne.py](../../tools/analysis_tools/visualize_tsne.py)
### MAE Visualization
We provide a tool to visualize the mask and reconstruction of the MAE model:
```shell
python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
```
Arguments:
- `IMG`: the image used for visualization.
- `CONFIG_FILE`: the config file of the pre-trained model.
- `CKPT_PATH`: the path of the pre-trained model's checkpoint.
- `DEVICE`: the device used for inference.
An example:
```shell
python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
```
### Reproducibility
If you want to make the model's performance exactly reproducible, you can set the `--deterministic` flag. Note, however, that enabling `--deterministic` turns off `torch.backends.cudnn.benchmark`, which slows down training.

View File

@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model
from .train import init_random_seed, set_random_seed, train_model
__all__ = ['init_random_seed', 'set_random_seed', 'train_model']
__all__ = [
'init_random_seed', 'inference_model', 'set_random_seed', 'train_model',
'init_model'
]

View File

@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import build_from_cfg
from PIL import Image
from torch import nn
from torchvision.transforms import Compose
from mmselfsup.datasets import PIPELINES
from mmselfsup.models import build_algorithm
def init_model(config: Union[str, mmcv.Config],
checkpoint: Optional[str] = None,
device: str = 'cuda:0',
options: Optional[dict] = None) -> nn.Module:
"""Initialize an model from config file.
Args:
config (str or :obj:``mmcv.Config``): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights. Defaults to None.
device (str): The device to put the model on.
Defaults to 'cuda:0'.
options (dict, optional): Options to override some settings in the used
config. Defaults to None.
Returns:
nn.Module: The initialized model.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
if options is not None:
config.merge_from_dict(options)
model = build_algorithm(config.model)
if checkpoint is not None:
# Mapping the weights to GPU may cause an unexpected video memory leak;
# see https://github.com/open-mmlab/mmdetection/pull/6405
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
def inference_model(
model: nn.Module,
data: Image.Image) -> Tuple[torch.Tensor, Union[torch.Tensor, dict]]:
"""Run inference on an image with the model.
Args:
model (nn.Module): The loaded model.
data (PIL.Image.Image): The loaded image.
Returns:
Tuple[torch.Tensor, Union[torch.Tensor, dict]]: Output of model
inference.
- data (torch.Tensor): The image that was fed into the model.
- output (torch.Tensor or dict[str, torch.Tensor]): the output
of the model in test mode.
"""
cfg = model.cfg
device = next(model.parameters()).device # model device
# build the data pipeline
test_pipeline = [
build_from_cfg(p, PIPELINES) for p in cfg.data.test.pipeline
]
test_pipeline = Compose(test_pipeline)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
# forward the model
with torch.no_grad():
output = model(data, mode='test')
return data, output
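
Taken together, `init_model` and `inference_model` give a two-call workflow. A minimal usage sketch (the paths are illustrative; any config that defines `cfg.data.test.pipeline`, like the MAE visualization config above, works the same way):

```python
from PIL import Image

from mmselfsup.apis import inference_model, init_model

# Illustrative config/checkpoint paths, reusing the files from the notebook.
model = init_model(
    'configs/selfsup/mae/mae_visualization.py',
    'mae_visualize_vit_large.pth',
    device='cpu')
img = Image.open('tests/data/color.jpg')
# data is the pipelined input tensor; output is whatever forward_test
# returns (for MAE, the (mask, pred) pair).
data, output = inference_model(model, img)
```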

View File

@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Dict, Tuple
import torch
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from .base import BaseModel
@@ -8,17 +12,22 @@ class MAE(BaseModel):
"""MAE.
Implementation of `Masked Autoencoders Are Scalable Vision Learners
<https://arxiv.org/abs/2111.06377>`_.
Args:
backbone (dict): Config dict for encoder. Defaults to None.
neck (dict): Config dict for encoder. Defaults to None.
head (dict): Config dict for loss functions. Defaults to None.
init_cfg (dict): Config dict for weight initialization.
backbone (dict, optional): Config dict for encoder. Defaults to None.
neck (dict, optional): Config dict for decoder. Defaults to None.
head (dict, optional): Config dict for loss functions. Defaults to None.
init_cfg (dict, optional): Config dict for weight initialization.
Defaults to None.
"""
def __init__(self, backbone=None, neck=None, head=None, init_cfg=None):
super(MAE, self).__init__(init_cfg)
def __init__(self,
backbone: Optional[dict] = None,
neck: Optional[dict] = None,
head: Optional[dict] = None,
init_cfg: Optional[dict] = None) -> None:
super().__init__(init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert neck is not None
@@ -28,31 +37,56 @@ class MAE(BaseModel):
self.head = build_head(head)
def init_weights(self):
super(MAE, self).init_weights()
super().init_weights()
def extract_feat(self, img):
def extract_feat(self, img: torch.Tensor) -> Tuple[torch.Tensor]:
"""Function to extract features from backbone.
Args:
img (Tensor): Input images of shape (N, C, H, W).
img (torch.Tensor): Input images of shape (N, C, H, W).
Returns:
tuple[Tensor]: backbone outputs.
Tuple[torch.Tensor]: backbone outputs.
"""
return self.backbone(img)
def forward_train(self, img, **kwargs):
def forward_train(self, img: torch.Tensor,
**kwargs) -> Dict[str, torch.Tensor]:
"""Forward computation during training.
Args:
img (Tensor): Input images of shape (N, C, H, W).
img (torch.Tensor): Input images of shape (N, C, H, W).
kwargs: Any keyword arguments to be used to forward.
Returns:
dict[str, Tensor]: A dictionary of loss components.
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
latent, mask, ids_restore = self.backbone(img)
pred = self.neck(latent, ids_restore)
losses = self.head(img, pred, mask)
return losses
def forward_test(self, img: torch.Tensor,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward computation during testing.
Args:
img (torch.Tensor): Input images of shape (N, C, H, W).
kwargs: Any keyword arguments to be used to forward.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Output of model test.
- mask: Mask used to mask the image.
- pred: The output of the neck.
"""
latent, mask, ids_restore = self.backbone(img)
pred = self.neck(latent, ids_restore)
pred = self.head.unpatchify(pred)
pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
mask = mask.detach()
mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
3) # (N, H*W, p*p*3)
mask = self.head.unpatchify(mask) # 1 is removing, 0 is keeping
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
return mask, pred
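
The reshaping in `forward_test` is easiest to follow with concrete numbers. The sketch below replays the mask path standalone, assuming a 224x224 input and `patch_size=16`, so `L = (224 / 16)**2 = 196`:

```python
import torch

N, L, p = 2, 196, 16
mask = torch.randint(0, 2, (N, L)).float()        # 1 is removing, 0 is keeping
mask = mask.unsqueeze(-1).repeat(1, 1, p**2 * 3)  # (N, L, p*p*3)

# unpatchify: fold the per-patch pixels back into image layout
h = w = int(L**0.5)
mask = mask.reshape(N, h, w, p, p, 3)
mask = torch.einsum('nhwpqc->nchpwq', mask).reshape(N, 3, h * p, w * p)

mask = torch.einsum('nchw->nhwc', mask)           # channel-last, as returned
assert mask.shape == (2, 224, 224, 3)
```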

View File

@@ -18,13 +18,18 @@ class MAEPretrainHead(BaseModule):
patch_size (int): Patch size. Defaults to 16.
"""
def __init__(self, norm_pix=False, patch_size=16):
super(MAEPretrainHead, self).__init__()
def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None:
super().__init__()
self.norm_pix = norm_pix
self.patch_size = patch_size
def patchify(self, imgs):
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
"""
Args:
imgs (torch.Tensor): The shape is (N, 3, H, W)
Returns:
x (torch.Tensor): The shape is (N, L, patch_size**2 * 3)
"""
p = self.patch_size
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
@@ -34,7 +39,24 @@ class MAEPretrainHead(BaseModule):
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def forward(self, x, pred, mask):
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): The shape is (N, L, patch_size**2 * 3)
Returns:
imgs (torch.Tensor): The shape is (N, 3, H, W)
"""
p = self.patch_size
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def forward(self, x: torch.Tensor, pred: torch.Tensor,
mask: torch.Tensor) -> dict:
losses = dict()
target = self.patchify(x)
if self.norm_pix:
@@ -60,7 +82,7 @@ class MAEFinetuneHead(BaseModule):
"""
def __init__(self, embed_dim, num_classes=1000, label_smooth_val=0.1):
super(MAEFinetuneHead, self).__init__()
super().__init__()
self.head = nn.Linear(embed_dim, num_classes)
self.criterion = LabelSmoothLoss(label_smooth_val, num_classes)
@@ -92,7 +114,7 @@ class MAELinprobeHead(BaseModule):
"""
def __init__(self, embed_dim, num_classes=1000):
super(MAELinprobeHead, self).__init__()
super().__init__()
self.head = nn.Linear(embed_dim, num_classes)
self.bn = nn.BatchNorm1d(embed_dim, affine=False, eps=1e-6)
self.criterion = nn.CrossEntropyLoss()
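
`patchify` and `unpatchify` are exact inverses for square inputs whose side is divisible by the patch size; a round-trip sanity check (a sketch, assuming `MAEPretrainHead` is importable from `mmselfsup.models.heads` as in this diff):

```python
import torch

from mmselfsup.models.heads import MAEPretrainHead

head = MAEPretrainHead(norm_pix=False, patch_size=16)
imgs = torch.rand(2, 3, 224, 224)
patches = head.patchify(imgs)     # (2, 196, 768): 14*14 patches of 16*16*3
recon = head.unpatchify(patches)  # (2, 3, 224, 224)
assert torch.allclose(imgs, recon)
```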

View File

@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import pytest
import torch
import torch.nn as nn
from mmcv import Config
from PIL import Image
from mmselfsup.apis import inference_model
from mmselfsup.models import BaseModel
class ExampleModel(BaseModel):
def __init__(self):
super(ExampleModel, self).__init__()
self.test_cfg = None
self.layer = nn.Linear(1, 1)
self.neck = nn.Identity()
def extract_feat(self, imgs):
pass
def forward_train(self, imgs, **kwargs):
pass
def forward_test(self, img, **kwargs):
out = self.layer(img)
return out
@pytest.mark.skipif(platform.system() == 'Windows', reason='')
def test_inference_model():
# Specify the data settings
cfg = Config.fromfile(
'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py' # noqa: E501
)
# Build the algorithm
model = ExampleModel()
model.cfg = cfg
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model.cfg.data = dict(
test=dict(pipeline=[
dict(type='Resize', size=(1, 1)),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]))
data = Image.open(
osp.join(osp.dirname(__file__), '..', 'data', 'color.jpg'))
# inference model
data, output = inference_model(model, data)
assert data.size() == torch.Size([1, 3, 1, 1])
assert output.size() == torch.Size([1, 3, 1, 1])

View File

@@ -33,5 +33,8 @@ def test_mae():
fake_input = torch.randn((2, 3, 224, 224))
fake_loss = alg.forward_train(fake_input)
fake_feature = alg.extract_feat(fake_input)
mask, pred = alg.forward_test(fake_input)
assert isinstance(fake_loss['loss'].item(), float)
assert list(fake_feature[0].shape) == [2, 50, 768]
assert list(mask.shape) == [2, 224, 224, 3]
assert list(pred.shape) == [2, 224, 224, 3]

View File

@@ -103,6 +103,10 @@ def test_mae_pretrain_head():
assert loss_norm_pixel['loss'].item() > 0
x = torch.rand((1, 4, 16**2 * 3))
imgs = head_norm_pixel.unpatchify(x)
assert imgs.size() == torch.Size((1, 3, 32, 32))
def test_mae_finetune_head():

View File

@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from mmselfsup.apis import inference_model, init_model
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def show_image(image: torch.Tensor, title: str = '') -> None:
# image is [H, W, 3]
assert image.shape[2] == 3
image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
255).int()
plt.imshow(image)
plt.title(title, fontsize=16)
plt.axis('off')
return
def show_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
im_paste: torch.Tensor) -> None:
# make the plt figure larger
plt.rcParams['figure.figsize'] = [24, 6]
plt.subplot(1, 4, 1)
show_image(x, 'original')
plt.subplot(1, 4, 2)
show_image(im_masked, 'masked')
plt.subplot(1, 4, 3)
show_image(y, 'reconstruction')
plt.subplot(1, 4, 4)
show_image(im_paste, 'reconstruction + visible')
plt.show()
def post_process(
x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x = torch.einsum('nchw->nhwc', x.cpu())
# masked image
im_masked = x * (1 - mask)
# MAE reconstruction pasted with visible patches
im_paste = x * (1 - mask) + y * mask
return x[0], im_masked[0], y[0], im_paste[0]
def main():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='MAE Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
print('Model loaded.')
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model.cfg.data = dict(
test=dict(pipeline=[
dict(type='Resize', size=(224, 224)),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]))
img = Image.open(args.img)
img, (mask, pred) = inference_model(model, img)
x, im_masked, y, im_paste = post_process(img, pred, mask)
show_images(x, im_masked, y, im_paste)
if __name__ == '__main__':
main()