Mirror of https://github.com/open-mmlab/mmselfsup.git, synced 2025-06-03 14:59:38 +08:00
[Tools]: MAE Reconstructed Image Visualization (#376)

* [Tools]: MAE reconstructed image visualization
* [Fix]: fix docstring and type hint
* [Fix]: fix docstring in MAE class
* [Fix]: fix type hint
* [Fix]: fix type hint and docstring
* [Refactor]: refactor super init
This commit is contained in:
parent: 074ae09568
commit: 3f530f085e
demo/mae_visualization.ipynb (new file, 268 lines)
@@ -0,0 +1,268 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) OpenMMLab. All rights reserved.\n",
    "\n",
    "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
    "\n",
    "Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
    "\n",
    "## Masked Autoencoders: Visualization Demo\n",
    "\n",
    "This is a visualization demo using our pre-trained MAE models. No GPU is needed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare\n",
    "Check environment. Install packages if in Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import requests\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "from mmselfsup.models import build_algorithm\n",
    "\n",
    "# check whether run in Colab\n",
    "if 'google.colab' in sys.modules:\n",
    "    print('Running in Colab.')\n",
    "    !pip install openmim\n",
    "    !mim install mmcv-full\n",
    "    !git clone https://github.com/open-mmlab/mmselfsup.git\n",
    "    %cd mmselfsup/\n",
    "    !pip install -e .\n",
    "    sys.path.append('./mmselfsup')\n",
    "    %cd demo\n",
    "else:\n",
    "    sys.path.append('..')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the utils\n",
    "\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "def show_image(image, title=''):\n",
    "    # image is [H, W, 3]\n",
    "    assert image.shape[2] == 3\n",
    "    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
    "    plt.imshow(image)\n",
    "    plt.title(title, fontsize=16)\n",
    "    plt.axis('off')\n",
    "    return\n",
    "\n",
    "\n",
    "def show_images(x, im_masked, y, im_paste):\n",
    "    # make the plt figure larger\n",
    "    plt.rcParams['figure.figsize'] = [24, 6]\n",
    "\n",
    "    plt.subplot(1, 4, 1)\n",
    "    show_image(x, \"original\")\n",
    "\n",
    "    plt.subplot(1, 4, 2)\n",
    "    show_image(im_masked, \"masked\")\n",
    "\n",
    "    plt.subplot(1, 4, 3)\n",
    "    show_image(y, \"reconstruction\")\n",
    "\n",
    "    plt.subplot(1, 4, 4)\n",
    "    show_image(im_paste, \"reconstruction + visible\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def post_process(x, y, mask):\n",
    "    x = torch.einsum('nchw->nhwc', x.cpu())\n",
    "    # masked image\n",
    "    im_masked = x * (1 - mask)\n",
    "    # MAE reconstruction pasted with visible patches\n",
    "    im_paste = x * (1 - mask) + y * mask\n",
    "    return x[0], im_masked[0], y[0], im_paste[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load an image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load an image\n",
    "img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
    "img_pil = Image.open(requests.get(img_url, stream=True).raw)\n",
    "img = img_pil.resize((224, 224))\n",
    "img = np.array(img) / 255.\n",
    "\n",
    "assert img.shape == (224, 224, 3)\n",
    "\n",
    "# normalize by ImageNet mean and std\n",
    "img = img - imagenet_mean\n",
    "img = img / imagenet_std\n",
    "\n",
    "plt.rcParams['figure.figsize'] = [5, 5]\n",
    "show_image(torch.tensor(img))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load a pre-trained MAE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
    "model = dict(\n",
    "    type='MAE',\n",
    "    backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
    "    neck=dict(\n",
    "        type='MAEPretrainDecoder',\n",
    "        patch_size=16,\n",
    "        in_chans=3,\n",
    "        embed_dim=1024,\n",
    "        decoder_embed_dim=512,\n",
    "        decoder_depth=8,\n",
    "        decoder_num_heads=16,\n",
    "        mlp_ratio=4.,\n",
    "    ),\n",
    "    head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))\n",
    "\n",
    "img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "# dataset summary\n",
    "data = dict(\n",
    "    test=dict(\n",
    "        pipeline=[\n",
    "            dict(type='Resize', size=(224, 224)),\n",
    "            dict(type='ToTensor'),\n",
    "            dict(type='Normalize', **img_norm_cfg),\n",
    "        ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
    "\n",
    "# download checkpoint if not exist\n",
    "# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
    "!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmselfsup.apis import init_model\n",
    "ckpt_path = \"mae_visualize_vit_large.pth\"\n",
    "model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
    "print('Model loaded.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run MAE on the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make random mask reproducible (comment out to make it change)\n",
    "torch.manual_seed(2)\n",
    "print('MAE with pixel reconstruction:')\n",
    "\n",
    "from mmselfsup.apis import inference_model\n",
    "\n",
    "img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
    "img = Image.open(requests.get(img_url, stream=True).raw)\n",
    "img, (mask, pred) = inference_model(model, img)\n",
    "x, im_masked, y, im_paste = post_process(img, pred, mask)\n",
    "show_images(x, im_masked, y, im_paste)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3e4aeeccd14e965f43d0896afbaf8d71604e66b8605affbaa33ec76aa4083757"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
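As an aside, the notebook writes its config to disk with `%%writefile` because `init_model` takes a file path, but `init_model` (added in this commit) also accepts an `mmcv.Config` object directly. A sketch of the equivalent in-memory route, using the same values as the cell above:

```python
# Sketch: build the visualization config in memory instead of %%writefile.
import mmcv

cfg = mmcv.Config(
    dict(
        model=dict(
            type='MAE',
            backbone=dict(
                type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),
            neck=dict(
                type='MAEPretrainDecoder',
                patch_size=16,
                in_chans=3,
                embed_dim=1024,
                decoder_embed_dim=512,
                decoder_depth=8,
                decoder_num_heads=16,
                mlp_ratio=4.,
            ),
            head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16)),
        data=dict(
            test=dict(pipeline=[
                dict(type='Resize', size=(224, 224)),
                dict(type='ToTensor'),
                dict(
                    type='Normalize',
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
            ]))))
# cfg can then be passed to init_model in place of the config file path.
```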
@@ -166,6 +166,27 @@ Arguments:
 - `WORK_DIR`: the directory to save the results of visualization.
 - `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)
+
+### MAE Visualization
+
+We provide a tool to visualize the mask and reconstructed image of an MAE model.
+
+```shell
+python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
+```
+
+Arguments:
+
+- `IMG`: the path of the image used for visualization.
+- `CONFIG_FILE`: the config file of the pre-trained model.
+- `CKPT_PATH`: the path of the model checkpoint.
+- `DEVICE`: the device used for inference.
+
+An example:
+
+```shell
+python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
+```
+
 ### Reproducibility

 If you want to make your performance exactly reproducible, please switch on `--deterministic` to train the final model to be published. Note that this flag will switch off `torch.backends.cudnn.benchmark` and slow down the training speed.
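The same flow can be driven from the Python API that this commit adds. A minimal sketch; the config and checkpoint names follow the demo notebook and are placeholders here:

```python
# Sketch of the CLI flow above via the Python API (placeholder paths).
import torch
from PIL import Image

from mmselfsup.apis import inference_model, init_model

model = init_model('configs/selfsup/mae/mae_visualization.py',
                   'mae_visualize_vit_large.pth', device='cpu')
torch.manual_seed(2)  # make the random mask reproducible
img, (mask, pred) = inference_model(model, Image.open('fox.jpg'))
```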
@@ -164,6 +164,27 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT
 - `WORK_DIR`: the directory to save the visualization results.
 - `[optional arguments]`: for optional arguments, please refer to [visualize_tsne.py](../../tools/analysis_tools/visualize_tsne.py)
+
+### MAE Visualization
+
+We provide a tool to visualize the masking and reconstruction results of an MAE model:
+
+```shell
+python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
+```
+
+Arguments:
+
+- `IMG`: the image used for visualization.
+- `CONFIG_FILE`: the config file used to train the pre-trained model.
+- `CKPT_PATH`: the path of the pre-trained model checkpoint.
+- `DEVICE`: the device used for inference.
+
+An example:
+
+```shell
+python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
+```
+
 ### Reproducibility

 If you want to make the model accuracy exactly reproducible, you can set the `--deterministic` flag. Note that enabling `--deterministic` turns off `torch.backends.cudnn.benchmark`, which slows down training.
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .inference import inference_model, init_model
 from .train import init_random_seed, set_random_seed, train_model

-__all__ = ['init_random_seed', 'set_random_seed', 'train_model']
+__all__ = [
+    'init_random_seed', 'inference_model', 'set_random_seed', 'train_model',
+    'init_model'
+]
mmselfsup/apis/inference.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import build_from_cfg
from PIL import Image
from torch import nn
from torchvision.transforms import Compose

from mmselfsup.datasets import PIPELINES
from mmselfsup.models import build_algorithm


def init_model(config: Union[str, mmcv.Config],
               checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               options: Optional[dict] = None) -> nn.Module:
    """Initialize a model from a config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights. Defaults to None.
        device (str): The device where the model will be put on.
            Defaults to 'cuda:0'.
        options (dict, optional): Options to override some settings in the
            used config. Defaults to None.

    Returns:
        nn.Module: The initialized model.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if options is not None:
        config.merge_from_dict(options)
    model = build_algorithm(config.model)
    if checkpoint is not None:
        # Mapping the weights to GPU may cause unexpected video memory leak,
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_model(
        model: nn.Module,
        data: Image.Image) -> Tuple[torch.Tensor, Union[torch.Tensor, dict]]:
    """Run inference on an image with the model.

    Args:
        model (nn.Module): The loaded model.
        data (PIL.Image.Image): The loaded image.

    Returns:
        Tuple[torch.Tensor, Union[torch.Tensor, dict]]: Output of model
            inference.

            - data (torch.Tensor): The loaded image tensor fed to the model.
            - output (torch.Tensor or dict[str, torch.Tensor]): the output
              of the test-mode model.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [
        build_from_cfg(p, PIPELINES) for p in cfg.data.test.pipeline
    ]
    test_pipeline = Compose(test_pipeline)

    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)

    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]

    # forward the model
    with torch.no_grad():
        output = model(data, mode='test')
    return data, output
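Note that `inference_model` builds its preprocessing from `model.cfg.data.test.pipeline`, so a config without a test pipeline needs one attached before inference; both the unit test and the visualization tool in this commit do exactly that. A sketch of the pattern, assuming `model` was already created with `init_model`:

```python
# Attach a test pipeline before calling inference_model; assumes `model`
# was created with init_model(...) from a config lacking data.test.pipeline.
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
model.cfg.data = dict(
    test=dict(pipeline=[
        dict(type='Resize', size=(224, 224)),
        dict(type='ToTensor'),
        dict(type='Normalize', **img_norm_cfg),
    ]))
```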
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple
+
+import torch
+
 from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
 from .base import BaseModel
@@ -8,17 +12,22 @@ class MAE(BaseModel):
     """MAE.

     Implementation of `Masked Autoencoders Are Scalable Vision Learners
     <https://arxiv.org/abs/2111.06377>`_.

     Args:
-        backbone (dict): Config dict for encoder. Defaults to None.
-        neck (dict): Config dict for encoder. Defaults to None.
-        head (dict): Config dict for loss functions. Defaults to None.
-        init_cfg (dict): Config dict for weight initialization.
+        backbone (dict, optional): Config dict for encoder. Defaults to None.
+        neck (dict, optional): Config dict for decoder. Defaults to None.
+        head (dict, optional): Config dict for loss functions.
+            Defaults to None.
+        init_cfg (dict, optional): Config dict for weight initialization.
             Defaults to None.
     """

-    def __init__(self, backbone=None, neck=None, head=None, init_cfg=None):
-        super(MAE, self).__init__(init_cfg)
+    def __init__(self,
+                 backbone: Optional[dict] = None,
+                 neck: Optional[dict] = None,
+                 head: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None) -> None:
+        super().__init__(init_cfg)
         assert backbone is not None
         self.backbone = build_backbone(backbone)
         assert neck is not None
@@ -28,31 +37,56 @@ class MAE(BaseModel):
         self.head = build_head(head)

     def init_weights(self):
-        super(MAE, self).init_weights()
+        super().init_weights()

-    def extract_feat(self, img):
+    def extract_feat(self, img: torch.Tensor) -> Tuple[torch.Tensor]:
         """Function to extract features from backbone.

         Args:
-            img (Tensor): Input images of shape (N, C, H, W).
+            img (torch.Tensor): Input images of shape (N, C, H, W).

         Returns:
-            tuple[Tensor]: backbone outputs.
+            Tuple[torch.Tensor]: backbone outputs.
         """
         return self.backbone(img)

-    def forward_train(self, img, **kwargs):
+    def forward_train(self, img: torch.Tensor,
+                      **kwargs) -> Dict[str, torch.Tensor]:
         """Forward computation during training.

         Args:
-            img (Tensor): Input images of shape (N, C, H, W).
+            img (torch.Tensor): Input images of shape (N, C, H, W).
             kwargs: Any keyword arguments to be used to forward.

         Returns:
-            dict[str, Tensor]: A dictionary of loss components.
+            Dict[str, torch.Tensor]: A dictionary of loss components.
         """
         latent, mask, ids_restore = self.backbone(img)
         pred = self.neck(latent, ids_restore)
         losses = self.head(img, pred, mask)

         return losses
+
+    def forward_test(self, img: torch.Tensor,
+                     **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward computation during testing.
+
+        Args:
+            img (torch.Tensor): Input images of shape (N, C, H, W).
+            kwargs: Any keyword arguments to be used to forward.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Output of model test.
+
+            - mask: Mask used to mask the image.
+            - pred: The output of the neck.
+        """
+        latent, mask, ids_restore = self.backbone(img)
+        pred = self.neck(latent, ids_restore)
+
+        pred = self.head.unpatchify(pred)
+        pred = torch.einsum('nchw->nhwc', pred).detach().cpu()
+
+        mask = mask.detach()
+        mask = mask.unsqueeze(-1).repeat(1, 1, self.head.patch_size**2 *
+                                         3)  # (N, H*W, p*p*3)
+        mask = self.head.unpatchify(mask)  # 1 is removing, 0 is keeping
+        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
+
+        return mask, pred
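To make the tensor bookkeeping in `forward_test` concrete, here is a standalone shape sketch (illustrative sizes only; the real values come from the backbone and head):

```python
import torch

# Illustrative shapes for forward_test with patch size p = 16, 224x224 input.
N, p = 2, 16
L = (224 // p)**2  # 196 patches per image
mask = torch.randint(0, 2, (N, L)).float()  # 1 is removing, 0 is keeping

# Broadcast each per-patch flag over all p*p*3 pixel values of its patch,
# giving the (N, L, p*p*3) layout that unpatchify expects; unpatchify then
# folds it to (N, 3, 224, 224), and the final einsum yields channel-last
# (N, 224, 224, 3), matching the shapes asserted in the MAE unit test below.
mask = mask.unsqueeze(-1).repeat(1, 1, p**2 * 3)
assert mask.shape == (N, L, p**2 * 3)
```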
@@ -18,13 +18,18 @@ class MAEPretrainHead(BaseModule):
         patch_size (int): Patch size. Defaults to 16.
     """

-    def __init__(self, norm_pix=False, patch_size=16):
-        super(MAEPretrainHead, self).__init__()
+    def __init__(self, norm_pix: bool = False, patch_size: int = 16) -> None:
+        super().__init__()
         self.norm_pix = norm_pix
         self.patch_size = patch_size

-    def patchify(self, imgs):
+    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            imgs (torch.Tensor): The shape is (N, 3, H, W)
+
+        Returns:
+            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3)
+        """
         p = self.patch_size
         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
@@ -34,7 +39,24 @@ class MAEPretrainHead(BaseModule):
         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
         return x

-    def forward(self, x, pred, mask):
+    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3)
+
+        Returns:
+            imgs (torch.Tensor): The shape is (N, 3, H, W)
+        """
+        p = self.patch_size
+        h = w = int(x.shape[1]**.5)
+        assert h * w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+        return imgs
+
+    def forward(self, x: torch.Tensor, pred: torch.Tensor,
+                mask: torch.Tensor) -> dict:
         losses = dict()
         target = self.patchify(x)
         if self.norm_pix:
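A quick way to convince yourself that `patchify` and `unpatchify` are exact inverses is a standalone round trip, written here as free functions mirroring the methods above:

```python
import torch


# Standalone sketch of the patchify/unpatchify pair above (p is the patch
# size; images are assumed square with H == W divisible by p).
def patchify(imgs: torch.Tensor, p: int = 16) -> torch.Tensor:
    h = w = imgs.shape[2] // p
    x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(imgs.shape[0], h * w, p**2 * 3)


def unpatchify(x: torch.Tensor, p: int = 16) -> torch.Tensor:
    h = w = int(x.shape[1]**.5)
    x = x.reshape(x.shape[0], h, w, p, p, 3)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(x.shape[0], 3, h * p, h * p)


imgs = torch.rand(1, 3, 32, 32)
# Exact inverse, matching the unit test below that feeds unpatchify a
# (1, 4, 16**2 * 3) input and expects a (1, 3, 32, 32) image back.
assert torch.allclose(unpatchify(patchify(imgs)), imgs)
```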
@@ -60,7 +82,7 @@ class MAEFinetuneHead(BaseModule):
     """

     def __init__(self, embed_dim, num_classes=1000, label_smooth_val=0.1):
-        super(MAEFinetuneHead, self).__init__()
+        super().__init__()
         self.head = nn.Linear(embed_dim, num_classes)
         self.criterion = LabelSmoothLoss(label_smooth_val, num_classes)
@@ -92,7 +114,7 @@ class MAELinprobeHead(BaseModule):
     """

     def __init__(self, embed_dim, num_classes=1000):
-        super(MAELinprobeHead, self).__init__()
+        super().__init__()
         self.head = nn.Linear(embed_dim, num_classes)
         self.bn = nn.BatchNorm1d(embed_dim, affine=False, eps=1e-6)
         self.criterion = nn.CrossEntropyLoss()
tests/test_apis/test_inference.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform

import pytest
import torch
import torch.nn as nn
from mmcv import Config
from PIL import Image

from mmselfsup.apis import inference_model
from mmselfsup.models import BaseModel


class ExampleModel(BaseModel):

    def __init__(self):
        super(ExampleModel, self).__init__()
        self.test_cfg = None
        self.layer = nn.Linear(1, 1)
        self.neck = nn.Identity()

    def extract_feat(self, imgs):
        pass

    def forward_train(self, imgs, **kwargs):
        pass

    def forward_test(self, img, **kwargs):
        out = self.layer(img)
        return out


@pytest.mark.skipif(platform.system() == 'Windows', reason='')
def test_inference_model():
    # Specify the data settings
    cfg = Config.fromfile(
        'configs/selfsup/relative_loc/relative-loc_resnet50_8xb64-steplr-70e_in1k.py'  # noqa: E501
    )

    # Build the algorithm
    model = ExampleModel()
    model.cfg = cfg

    img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    model.cfg.data = dict(
        test=dict(pipeline=[
            dict(type='Resize', size=(1, 1)),
            dict(type='ToTensor'),
            dict(type='Normalize', **img_norm_cfg),
        ]))

    data = Image.open(
        osp.join(osp.dirname(__file__), '..', 'data', 'color.jpg'))

    # inference model
    data, output = inference_model(model, data)
    assert data.size() == torch.Size([1, 3, 1, 1])
    assert output.size() == torch.Size([1, 3, 1, 1])
@@ -33,5 +33,8 @@ def test_mae():
     fake_input = torch.randn((2, 3, 224, 224))
     fake_loss = alg.forward_train(fake_input)
     fake_feature = alg.extract_feat(fake_input)
+    mask, pred = alg.forward_test(fake_input)
     assert isinstance(fake_loss['loss'].item(), float)
     assert list(fake_feature[0].shape) == [2, 50, 768]
+    assert list(mask.shape) == [2, 224, 224, 3]
+    assert list(pred.shape) == [2, 224, 224, 3]
@@ -103,6 +103,10 @@ def test_mae_pretrain_head():

     assert loss_norm_pixel['loss'].item() > 0

+    x = torch.rand((1, 4, 16**2 * 3))
+    imgs = head_norm_pixel.unpatchify(x)
+    assert imgs.size() == torch.Size((1, 3, 32, 32))
+

 def test_mae_finetune_head():
tools/misc/mae_visualization.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified from https://colab.research.google.com/github/facebookresearch/mae
# /blob/main/demo/mae_visualize.ipynb
from argparse import ArgumentParser
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from mmselfsup.apis import inference_model, init_model

imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def show_image(image: torch.Tensor, title: str = '') -> None:
    # image is [H, W, 3]
    assert image.shape[2] == 3
    image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0,
                       255).int()
    plt.imshow(image)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return


def show_images(x: torch.Tensor, im_masked: torch.Tensor, y: torch.Tensor,
                im_paste: torch.Tensor) -> None:
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 6]

    plt.subplot(1, 4, 1)
    show_image(x, 'original')

    plt.subplot(1, 4, 2)
    show_image(im_masked, 'masked')

    plt.subplot(1, 4, 3)
    show_image(y, 'reconstruction')

    plt.subplot(1, 4, 4)
    show_image(im_paste, 'reconstruction + visible')

    plt.show()


def post_process(
    x: torch.Tensor, y: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    x = torch.einsum('nchw->nhwc', x.cpu())
    # masked image
    im_masked = x * (1 - mask)
    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask
    return x[0], im_masked[0], y[0], im_paste[0]


def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='MAE Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    print('Model loaded.')

    # make random mask reproducible (comment out to make it change)
    torch.manual_seed(2)
    print('MAE with pixel reconstruction:')

    img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    model.cfg.data = dict(
        test=dict(pipeline=[
            dict(type='Resize', size=(224, 224)),
            dict(type='ToTensor'),
            dict(type='Normalize', **img_norm_cfg),
        ]))

    img = Image.open(args.img)
    img, (mask, pred) = inference_model(model, img)
    x, im_masked, y, im_paste = post_process(img, pred, mask)
    show_images(x, im_masked, y, im_paste)


if __name__ == '__main__':
    main()