From 41747f73c774757ac7585f16f9f9a90960f7e9e1 Mon Sep 17 00:00:00 2001
From: RenQin <45731309+soonera@users.noreply.github.com>
Date: Fri, 28 Oct 2022 15:37:08 +0800
Subject: [PATCH] [Refactor]: refactor MAE visualization (#471)

* [Refactor]: refactor MAE visualization

* [Fix]: fix lint

* [Refactor]: refactor MAE visualization

* [Feature]: add mae_visualization.py

* [UT]: add unit test

* [Refactor]: move mae_visualization.py to tools/analysis_tools

* [Docs]: Add the purpose of the function unpatchify()

* [Fix]: fix lint
---
 demo/mae_visualization.ipynb                  | 403 ++++++++++++++++++
 mmselfsup/apis/__init__.py                    |   4 +
 mmselfsup/apis/inference.py                   |  85 ++++
 mmselfsup/models/algorithms/mae.py            |  37 +-
 mmselfsup/models/heads/mae_head.py            |  17 +
 tests/test_apis/test_inference.py             |  62 +++
 tests/test_models/test_algorithms/test_mae.py |   4 +
 tools/analysis_tools/mae_visualization.py     | 108 +++++
 8 files changed, 719 insertions(+), 1 deletion(-)
 create mode 100644 demo/mae_visualization.ipynb
 create mode 100644 mmselfsup/apis/__init__.py
 create mode 100644 mmselfsup/apis/inference.py
 create mode 100644 tests/test_apis/test_inference.py
 create mode 100644 tools/analysis_tools/mae_visualization.py

diff --git a/demo/mae_visualization.ipynb b/demo/mae_visualization.ipynb
new file mode 100644
index 00000000..850c24c6
--- /dev/null
+++ b/demo/mae_visualization.ipynb
@@ -0,0 +1,403 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Copyright (c) OpenMMLab. All rights reserved.\n",
+    "\n",
+    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
+    "\n",
+    "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
+    "\n",
+    "## Masked Autoencoders: Visualization Demo\n",
+    "\n",
+    "This is a visualization demo using our pre-trained MAE models. No GPU is needed."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Prepare\n",
+    "Check environment. Install packages if in Colab."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "\n",
+    "# check whether running in Colab\n",
+    "if 'google.colab' in sys.modules:\n",
+    "    print('Running in Colab.')\n",
+    "    !pip install -U openmim\n",
+    "    !mim install 'mmengine==0.1.0' 'mmcv>=2.0.0rc1'\n",
+    "\n",
+    "    !git clone https://github.com/open-mmlab/mmselfsup.git\n",
+    "    %cd mmselfsup/\n",
+    "    !git checkout dev-1.x\n",
+    "    !pip install -e .\n",
+    "\n",
+    "    sys.path.append('./mmselfsup')\n",
+    "    %cd demo\n",
+    "else:\n",
+    "    sys.path.append('..')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from mmengine.dataset import Compose, default_collate\n",
+    "\n",
+    "from mmselfsup.apis import inference_model\n",
+    "from mmselfsup.models.utils import SelfSupDataPreprocessor\n",
+    "from mmselfsup.registry import MODELS\n",
+    "from mmselfsup.utils import register_all_modules"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Define utils"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the utils\n",
+    "\n",
+    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
+    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
+    "\n",
+    "def show_image(image, title=''):\n",
+    "    # image is [H, W, 3]\n",
+    "    assert image.shape[2] == 3\n",
+    "    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
+    "    plt.imshow(image)\n",
+    "    plt.title(title, fontsize=16)\n",
+    "    plt.axis('off')\n",
+    "    return\n",
+    "\n",
+    "\n",
+    "def show_images(x, im_masked, y, im_paste):\n",
+    "    # make the plt figure larger\n",
+    "    plt.rcParams['figure.figsize'] = [24, 6]\n",
+    "\n",
+    "    plt.subplot(1, 4, 1)\n",
+    "    show_image(x, \"original\")\n",
+    "\n",
+    "    plt.subplot(1, 4, 2)\n",
+    "    show_image(im_masked, \"masked\")\n",
+    "\n",
+    "    plt.subplot(1, 4, 3)\n",
+    "    show_image(y, \"reconstruction\")\n",
+    "\n",
+    "    plt.subplot(1, 4, 4)\n",
+    "    show_image(im_paste, \"reconstruction + visible\")\n",
+    "\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "def post_process(x, y, mask):\n",
+    "    x = torch.einsum('nchw->nhwc', x.cpu())\n",
+    "    # masked image\n",
+    "    im_masked = x * (1 - mask)\n",
+    "    # MAE reconstruction pasted with visible patches\n",
+    "    im_paste = x * (1 - mask) + y * mask\n",
+    "    return x[0], im_masked[0], y[0], im_paste[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Prepare config file"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Overwriting ../configs/selfsup/mae/mae_visualization.py\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
+    "model = dict(\n",
+    "    type='MAE',\n",
+    "    data_preprocessor=dict(\n",
+    "        mean=[123.675, 116.28, 103.53],\n",
+    "        std=[58.395, 57.12, 57.375],\n",
+    "        bgr_to_rgb=True),\n",
+    "    backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
+    "    neck=dict(\n",
+    "        type='MAEPretrainDecoder',\n",
+    "        patch_size=16,\n",
+    "        in_chans=3,\n",
+    "        embed_dim=1024,\n",
+    "        decoder_embed_dim=512,\n",
+    "        decoder_depth=8,\n",
+    "        decoder_num_heads=16,\n",
+    "        mlp_ratio=4.,\n",
+    "    ),\n",
+    "    head=dict(\n",
+    "        type='MAEPretrainHead',\n",
+    "        norm_pix=True,\n",
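+    "        # norm_pix=True: reconstruction targets are per-patch normalized pixels\n",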
+    "        patch_size=16,\n",
+    "        loss=dict(type='MAEReconstructionLoss')),\n",
+    "    init_cfg=[\n",
+    "        dict(type='Xavier', distribution='uniform', layer='Linear'),\n",
+    "        dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)\n",
+    "    ])\n",
+    "\n",
+    "file_client_args = dict(backend='disk')\n",
+    "\n",
+    "# dataset summary\n",
+    "test_dataloader = dict(\n",
+    "    dataset=dict(pipeline=[\n",
+    "        dict(type='LoadImageFromFile', file_client_args=file_client_args),\n",
+    "        dict(type='Resize', scale=(224, 224)),\n",
+    "        dict(type='PackSelfSupInputs', meta_keys=['img_path'])\n",
+    "    ]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load a pre-trained MAE model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--2022-09-03 00:34:55--  https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth\n",
+      "Resolving download.openmmlab.com (download.openmmlab.com)... 47.107.10.247\n",
+      "Connecting to download.openmmlab.com (download.openmmlab.com)|47.107.10.247|:443... connected.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 1318299501 (1.2G) [application/octet-stream]\n",
+      "Saving to: 'mae_visualize_vit_large.pth'\n",
+      "\n",
+      "mae_visualize_vit_l 100%[===================>]   1.23G  3.22MB/s    in 6m 4s   \n",
+      "\n",
+      "2022-09-03 00:40:59 (3.46 MB/s) - 'mae_visualize_vit_large.pth' saved [1318299501/1318299501]\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
+    "\n",
+    "# download the checkpoint if it does not exist\n",
+    "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
+    "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "local loads checkpoint from path: mae_visualize_vit_large.pth\n",
+      "Model loaded.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from mmselfsup.apis import init_model\n",
+    "ckpt_path = \"mae_visualize_vit_large.pth\"\n",
+    "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
+    "print('Model loaded.')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load an image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# make random mask reproducible (comment out to make it change)\n",
+    "register_all_modules()\n",
+    "torch.manual_seed(2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--2022-09-03 00:41:01--  https://download.openmmlab.com/mmselfsup/mae/fox.jpg\n",
+      "Resolving download.openmmlab.com (download.openmmlab.com)... 101.133.111.186\n",
+      "Connecting to download.openmmlab.com (download.openmmlab.com)|101.133.111.186|:443... connected.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 60133 (59K) [image/jpeg]\n",
+      "Saving to: 'fox.jpg'\n",
+      "\n",
+      "fox.jpg             100%[===================>]  58.72K  --.-KB/s    in 0.06s   \n",
+      "\n",
+      "2022-09-03 00:41:01 (962 KB/s) - 'fox.jpg' saved [60133/60133]\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "!wget -nc 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "img_path = 'fox.jpg'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cfg = model.cfg\n",
+    "test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)\n",
+    "data_preprocessor = MODELS.build(cfg.model.data_preprocessor)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = dict(img_path=img_path)\n",
+    "data = test_pipeline(data)\n",
+    "data = default_collate([data])\n",
+    "img, _ = data_preprocessor(data, False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.rcParams['figure.figsize'] = [5, 5]\n",
+    "show_image(torch.einsum('nchw->nhwc', img[0].cpu())[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Run MAE on the image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "results = inference_model(model, img_path)\n",
+    "x, im_masked, y, im_paste = post_process(img[0], results.pred.value, results.mask.value)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print('MAE with pixel reconstruction:')\n",
+    "show_images(x, im_masked, y, im_paste)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "1742319693997e01e5942276ccf039297cd0a474ab9a20f711b7fa536eca5436"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mmselfsup/apis/__init__.py b/mmselfsup/apis/__init__.py
new file mode 100644
index 00000000..08ccc087
--- /dev/null
+++ b/mmselfsup/apis/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .inference import inference_model, init_model
+
+__all__ = ['init_model', 'inference_model']
diff --git a/mmselfsup/apis/inference.py b/mmselfsup/apis/inference.py
new file mode 100644
index 00000000..f448b8f8
--- /dev/null
+++ b/mmselfsup/apis/inference.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Optional, Union
+
+import numpy as np
+import torch
+from mmengine.config import Config
+from mmengine.dataset import Compose, default_collate
+from mmengine.runner import load_checkpoint
+from torch import nn
+
+from mmselfsup.models import build_algorithm
+from mmselfsup.structures import SelfSupDataSample
+
+
+def init_model(config: Union[str, Config],
+               checkpoint: Optional[str] = None,
+               device: str = 'cuda:0',
+               options: Optional[dict] = None) -> nn.Module:
+    """Initialize a model from a config file.
+
+    Args:
+        config (str or :obj:`mmengine.Config`): Config file path or the
+            config object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the
+            model will not load any weights.
+        device (str): The device where the model will be put on.
+            Defaults to 'cuda:0'.
+        options (dict, optional): Options to override some settings in the
+            used config.
+
+    Returns:
+        nn.Module: The initialized model.
+    """
+    if isinstance(config, str):
+        config = Config.fromfile(config)
+    elif not isinstance(config, Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    if options is not None:
+        config.merge_from_dict(options)
+    config.model.pretrained = None
+    config.model.setdefault('data_preprocessor',
+                            config.get('data_preprocessor', None))
+    model = build_algorithm(config.model)
+    if checkpoint is not None:
+        # Mapping the weights to GPU may cause an unexpected video memory
+        # leak, see https://github.com/open-mmlab/mmdetection/pull/6405
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def inference_model(model: nn.Module,
+                    img: Union[str, np.ndarray]) -> SelfSupDataSample:
+    """Run inference on an image with the given MMSelfSup model.
+
+    Args:
+        model (nn.Module): The loaded model.
+        img (Union[str, np.ndarray]): Either an image path or a loaded
+            image.
+
+    Returns:
+        SelfSupDataSample: Output of the model inference.
+    """
+    cfg = model.cfg
+    # build the data pipeline
+    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+    if isinstance(img, str):
+        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
+            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
+        data = dict(img_path=img)
+    else:
+        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
+            test_pipeline_cfg.pop(0)
+        data = dict(img=img)
+    test_pipeline = Compose(test_pipeline_cfg)
+    data = test_pipeline(data)
+    data = default_collate([data])
+
+    # forward the model
+    with torch.no_grad():
+        results = model.test_step(data)
+    return results
diff --git a/mmselfsup/models/algorithms/mae.py b/mmselfsup/models/algorithms/mae.py
index 42a8837c..1e89ca4b 100644
--- a/mmselfsup/models/algorithms/mae.py
+++ b/mmselfsup/models/algorithms/mae.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
+from mmengine.structures import BaseDataElement
 
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
@@ -30,6 +31,40 @@ class MAE(BaseModel):
         pred = self.neck(latent, ids_restore)
         return pred
 
+    def predict(self,
+                inputs: List[torch.Tensor],
+                data_samples: Optional[List[SelfSupDataSample]] = None,
+                **kwargs) -> SelfSupDataSample:
+        """The forward function for testing, mainly for image
+        reconstruction.
+
+        Args:
+            inputs (List[torch.Tensor]): The input images.
+            data_samples (List[SelfSupDataSample], optional): All elements
+                required during the forward function.
+
+        Returns:
+            SelfSupDataSample: The prediction of the model.
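+
+        Example (illustrative; the expected shapes follow the unit test
+        added later in this patch):
+            >>> results = model(inputs, data_samples, mode='predict')
+            >>> results.pred.value.shape  # (N, H, W, 3) reconstruction
+            >>> results.mask.value.shape  # (N, H, W, 3); 1 = masked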
+        """
+
+        latent, mask, ids_restore = self.backbone(inputs[0])
+        pred = self.neck(latent, ids_restore)
+
+        pred = self.head.unpatchify(pred)
+        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
+
+        mask = mask.detach()
+        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
+                                         3)  # (N, H*W, p*p*3)
+        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
+        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
+
+        results = SelfSupDataSample()
+        results.mask = BaseDataElement(**dict(value=mask))
+        results.pred = BaseDataElement(**dict(value=pred))
+
+        return results
+
     def loss(self, inputs: List[torch.Tensor],
              data_samples: List[SelfSupDataSample],
              **kwargs) -> Dict[str, torch.Tensor]:
diff --git a/mmselfsup/models/heads/mae_head.py b/mmselfsup/models/heads/mae_head.py
index d6e5f6ab..e1281878 100644
--- a/mmselfsup/models/heads/mae_head.py
+++ b/mmselfsup/models/heads/mae_head.py
@@ -43,6 +43,23 @@ class MAEPretrainHead(BaseModule):
         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
         return x
 
+    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+        """Combine non-overlapped patches into images.
+
+        Args:
+            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3).
+
+        Returns:
+            imgs (torch.Tensor): The shape is (N, 3, H, W).
+        """
+        p = self.patch_size
+        h = w = int(x.shape[1]**.5)
+        assert h * w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
+        return imgs
+
     def construct_target(self, target: torch.Tensor) -> torch.Tensor:
         """Construct the reconstruction target.
diff --git a/tests/test_apis/test_inference.py b/tests/test_apis/test_inference.py
new file mode 100644
index 00000000..aa37cc0b
--- /dev/null
+++ b/tests/test_apis/test_inference.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
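+"""Tests for the ``mmselfsup.apis`` inference helpers.
+
+A stub ``ExampleModel`` stands in for a real algorithm, so no checkpoint
+is needed; the test only exercises the pipeline handling inside
+``inference_model``.
+"""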
+import os.path as osp
+import platform
+from typing import List, Optional
+
+import pytest
+import torch
+import torch.nn as nn
+from mmengine.config import Config
+
+from mmselfsup.apis import inference_model
+from mmselfsup.models import BaseModel
+from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import register_all_modules
+
+backbone = dict(
+    type='ResNet',
+    depth=18,
+    in_channels=2,
+    out_indices=[4],  # 0: conv-1, x: stage-x
+    norm_cfg=dict(type='BN'))
+
+
+class ExampleModel(BaseModel):
+
+    def __init__(self, backbone=backbone):
+        super(ExampleModel, self).__init__(backbone=backbone)
+        self.layer = nn.Linear(1, 1)
+
+    def predict(self,
+                inputs: List[torch.Tensor],
+                data_samples: Optional[List[SelfSupDataSample]] = None,
+                **kwargs) -> SelfSupDataSample:
+        out = self.layer(inputs[0])
+        return out
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='')
+def test_inference_model():
+    register_all_modules()
+
+    # Specify the data settings
+    cfg = Config.fromfile(
+        'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py'  # noqa: E501
+    )
+    # Build the algorithm
+    model = ExampleModel()
+    model.cfg = cfg
+    model.cfg.test_dataloader = dict(
+        dataset=dict(pipeline=[
+            dict(
+                type='LoadImageFromFile',
+                file_client_args=dict(backend='disk')),
+            dict(type='Resize', scale=(1, 1)),
+            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
+        ]))
+
+    img_path = osp.join(osp.dirname(__file__), '..', 'data', 'color.jpg')
+
+    # inference model
+    out = inference_model(model, img_path)
+    assert out.size() == torch.Size([1, 3, 1, 1])
diff --git a/tests/test_models/test_algorithms/test_mae.py b/tests/test_models/test_algorithms/test_mae.py
index 83a5bdf2..51546e00 100644
--- a/tests/test_models/test_algorithms/test_mae.py
+++ b/tests/test_models/test_algorithms/test_mae.py
@@ -50,3 +50,7 @@ def test_mae():
 
     fake_feats = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
     assert list(fake_feats.shape) == [2, 196, 768]
+
+    results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
+    assert list(results.mask.value.shape) == [2, 224, 224, 3]
+    assert list(results.pred.value.shape) == [2, 224, 224, 3]
diff --git a/tools/analysis_tools/mae_visualization.py b/tools/analysis_tools/mae_visualization.py
new file mode 100644
index 00000000..2f6902ce
--- /dev/null
+++ b/tools/analysis_tools/mae_visualization.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified from https://colab.research.google.com/github/facebookresearch/mae
+# /blob/main/demo/mae_visualize.ipynb
+from argparse import ArgumentParser
+from typing import Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from mmengine.dataset import Compose, default_collate
+
+from mmselfsup.apis import inference_model, init_model
+from mmselfsup.registry import MODELS
+from mmselfsup.utils import register_all_modules
+
+imagenet_mean = np.array([0.485, 0.456, 0.406])
+imagenet_std = np.array([0.229, 0.224, 0.225])
+
+
+def show_image(image: torch.Tensor, title: str = '') -> None:
+    # image is [H, W, 3]
+    assert image.shape[2] == 3
+    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
+                       255).int()
+    plt.imshow(image)
+    plt.title(title, fontsize=16)
+    plt.axis('off')
+    return
+
+
+def save_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
+                im_paste: torch.Tensor, out_file: str) -> None:
+    # make the plt figure larger
+    plt.rcParams['figure.figsize'] = [24, 6]
+
+    plt.subplot(1, 4, 1)
+    show_image(x, 'original')
+
+    plt.subplot(1, 4, 2)
+    show_image(im_masked, 'masked')
+
+    plt.subplot(1, 4, 3)
+    show_image(y, 'reconstruction')
+
+    plt.subplot(1, 4, 4)
+    show_image(im_paste, 'reconstruction + visible')
+
+    plt.savefig(out_file)
+
+
+def post_process(
+    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    x = torch.einsum('nchw->nhwc', x.cpu())
+    # masked image
+    im_masked = x * (1 - mask)
+    # MAE reconstruction pasted with visible patches
+    im_paste = x * (1 - mask) + y * mask
+    return x[0], im_masked[0], y[0], im_paste[0]
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('img_path', help='Image file path')
+    parser.add_argument('config', help='MAE Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument('out_file', help='The output image file path')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    args = parser.parse_args()
+
+    register_all_modules()
+
+    # build the model from a config file and a checkpoint file
+    model = init_model(args.config, args.checkpoint, device=args.device)
+    print('Model loaded.')
+
+    # make random mask reproducible (comment out to make it change)
+    torch.manual_seed(2)
+    print('MAE with pixel reconstruction:')
+
+    model.cfg.test_dataloader = dict(
+        dataset=dict(pipeline=[
+            dict(
+                type='LoadImageFromFile',
+                file_client_args=dict(backend='disk')),
+            dict(type='Resize', scale=(224, 224)),
+            dict(type='PackSelfSupInputs', meta_keys=['img_path'])
+        ]))
+
+    results = inference_model(model, args.img_path)
+
+    cfg = model.cfg
+    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
+    data_preprocessor = MODELS.build(cfg.model.data_preprocessor)
+    data = dict(img_path=args.img_path)
+    data = test_pipeline(data)
+    data = default_collate([data])
+    img, _ = data_preprocessor(data, False)
+
+    x, im_masked, y, im_paste = post_process(img[0], results.pred.value,
+                                             results.mask.value)
+    save_images(x, im_masked, y, im_paste, args.out_file)
+
+
+if __name__ == '__main__':
+    main()
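
For reference, this is how the pieces added by this patch fit together end
to end (a minimal sketch, not part of the diff: it assumes the config
written by the notebook above and the downloaded checkpoint are present in
the working directory; the output filename mae_reconstruction.jpg is
arbitrary):

    from mmselfsup.apis import inference_model, init_model
    from mmselfsup.utils import register_all_modules

    register_all_modules()
    model = init_model('configs/selfsup/mae/mae_visualization.py',
                       'mae_visualize_vit_large.pth', device='cpu')
    results = inference_model(model, 'fox.jpg')
    print(results.pred.value.shape)  # torch.Size([1, 224, 224, 3])
    print(results.mask.value.shape)  # torch.Size([1, 224, 224, 3])

    # Equivalent invocation of the new command-line tool:
    # python tools/analysis_tools/mae_visualization.py fox.jpg \
    #     configs/selfsup/mae/mae_visualization.py \
    #     mae_visualize_vit_large.pth mae_reconstruction.jpg --device cpu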