Add support for converting a inpainting model to ONNX and TensorRT (#1831)
* Add support for inpainting models * Add configs * Add comment * Refactor * Add test code for inpainting task * Fix * Fix * Update * Fix * Fix * Update docs * Update * Fix visualization * Handle case without Resizepull/1944/head
parent
aae9f32623
commit
f7c484a046
|
@ -0,0 +1,20 @@
|
|||
_base_ = ['./inpainting_static.py']
|
||||
|
||||
onnx_config = dict(
|
||||
dynamic_axes=dict(
|
||||
masked_img={
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
mask={
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
},
|
||||
output={
|
||||
0: 'batch',
|
||||
2: 'height',
|
||||
3: 'width'
|
||||
}),
|
||||
input_shape=None)
|
|
@ -0,0 +1 @@
|
|||
_base_ = ['./inpainting_dynamic.py', '../../_base_/backends/onnxruntime.py']
|
|
@ -0,0 +1,3 @@
|
|||
_base_ = ['./inpainting_static.py', '../../_base_/backends/onnxruntime.py']
|
||||
|
||||
onnx_config = dict(input_shape=[256, 256])
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = ['../../_base_/onnx_config.py']
|
||||
|
||||
codebase_config = dict(type='mmedit', task='Inpainting')
|
||||
onnx_config = dict(
|
||||
input_names=['masked_img', 'mask'], output_names=['fake_img'])
|
|
@ -0,0 +1,17 @@
|
|||
_base_ = ['./inpainting_static.py', '../../_base_/backends/tensorrt-fp16.py']
|
||||
|
||||
onnx_config = dict(input_shape=[256, 256])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
masked_img=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256]),
|
||||
mask=dict(
|
||||
min_shape=[1, 1, 256, 256],
|
||||
opt_shape=[1, 1, 256, 256],
|
||||
max_shape=[1, 1, 256, 256])))
|
||||
])
|
|
@ -0,0 +1,17 @@
|
|||
_base_ = ['./inpainting_static.py', '../../_base_/backends/tensorrt-int8.py']
|
||||
|
||||
onnx_config = dict(input_shape=[256, 256])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
masked_img=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256]),
|
||||
mask=dict(
|
||||
min_shape=[1, 1, 256, 256],
|
||||
opt_shape=[1, 1, 256, 256],
|
||||
max_shape=[1, 1, 256, 256])))
|
||||
])
|
|
@ -0,0 +1,17 @@
|
|||
_base_ = ['./inpainting_static.py', '../../_base_/backends/tensorrt.py']
|
||||
|
||||
onnx_config = dict(input_shape=[256, 256])
|
||||
backend_config = dict(
|
||||
common_config=dict(max_workspace_size=1 << 30),
|
||||
model_inputs=[
|
||||
dict(
|
||||
input_shapes=dict(
|
||||
masked_img=dict(
|
||||
min_shape=[1, 3, 256, 256],
|
||||
opt_shape=[1, 3, 256, 256],
|
||||
max_shape=[1, 3, 256, 256]),
|
||||
mask=dict(
|
||||
min_shape=[1, 1, 256, 256],
|
||||
opt_shape=[1, 1, 256, 256],
|
||||
max_shape=[1, 1, 256, 256])))
|
||||
])
|
|
@ -8,13 +8,20 @@ Please refer to [official installation guide](https://mmediting.readthedocs.io/e
|
|||
|
||||
## MMEditing models support
|
||||
|
||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||
| :---------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: |
|
||||
| SRCNN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||
| ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| ESRGAN-PSNR | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| SRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| SRResNet | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| Real-ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||
| EDSR | super-resolution | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||
| RDN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||
| :------------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: |
|
||||
| SRCNN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||
| ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| ESRGAN-PSNR | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||
| SRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| SRResNet | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||
| Real-ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||
| EDSR | super-resolution | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||
| RDN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||
| Global&Local\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/global_local) |
|
||||
| DeepFillv1\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/deepfillv1) |
|
||||
| PConv\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/partial_conv) |
|
||||
| DeepFillv2\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/deepfillv2) |
|
||||
| AOT-GAN\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/AOT-GAN) |
|
||||
|
||||
1. We skipped quantitative evaluation for image inpainting due to the high computational cost required for testing.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .deploy import MMEditing, SuperResolution
|
||||
from .deploy import Inpainting, MMEditing, SuperResolution
|
||||
from .models import * # noqa: F401,F403
|
||||
|
||||
__all__ = ['MMEditing', 'SuperResolution']
|
||||
__all__ = ['MMEditing', 'SuperResolution', 'Inpainting']
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.codebase.mmedit.deploy.inpainting import Inpainting
|
||||
from mmdeploy.codebase.mmedit.deploy.mmediting import MMEditing
|
||||
from mmdeploy.codebase.mmedit.deploy.super_resolution import SuperResolution
|
||||
|
||||
__all__ = ['MMEditing', 'SuperResolution']
|
||||
__all__ = ['MMEditing', 'SuperResolution', 'Inpainting']
|
||||
|
|
|
@ -0,0 +1,293 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.parallel import collate, scatter
|
||||
from mmcv.utils import Config
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmdeploy.codebase.base import BaseTask
|
||||
from mmdeploy.codebase.mmedit.deploy.mmediting import MMEDIT_TASK
|
||||
from mmdeploy.utils import Task, get_ir_config, load_config
|
||||
|
||||
|
||||
def process_model_config(model_cfg: Config,
|
||||
imgs: Union[Sequence[str], Sequence[np.ndarray]],
|
||||
input_shape: Optional[Sequence[int]] = None):
|
||||
"""Process the model config.
|
||||
|
||||
Args:
|
||||
model_cfg (Config): The model config.
|
||||
imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted
|
||||
data type are List[str], List[np.ndarray].
|
||||
input_shape (Sequence[int], optional): A list of two integer
|
||||
in (width, height) format specifying input shape. Default: None.
|
||||
|
||||
Returns:
|
||||
Config: the model config after processing.
|
||||
"""
|
||||
config = load_config(model_cfg)[0].copy()
|
||||
load_from_file = isinstance(imgs[0], str)
|
||||
if not load_from_file:
|
||||
# Remove 'LoadImageFromFile'
|
||||
config.test_pipeline.pop(0)
|
||||
|
||||
if input_shape is not None:
|
||||
# Fix the input shape by 'Resize' or 'RandomResizedCrop'
|
||||
for pipeline in config.test_pipeline[::-1]:
|
||||
if pipeline.type == 'Resize':
|
||||
pipeline.scale = tuple(input_shape)
|
||||
break
|
||||
elif pipeline.type == 'RandomResizedCrop':
|
||||
pipeline.crop_size = tuple(input_shape)
|
||||
break
|
||||
|
||||
key = 'gt_img_path'
|
||||
for pipeline in config.test_pipeline:
|
||||
if 'meta_keys' in pipeline:
|
||||
while key in pipeline['meta_keys']:
|
||||
pipeline['meta_keys'].remove(key)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@MMEDIT_TASK.register_module(Task.INPAINTING.value)
|
||||
class Inpainting(BaseTask):
|
||||
"""BaseTask class of inpainting task.
|
||||
|
||||
Args:
|
||||
model_cfg (Config): Model config file.
|
||||
deploy_cfg (Config): Deployment config file.
|
||||
device (str): A string specifying device type.
|
||||
"""
|
||||
|
||||
def __init__(self, model_cfg: Config, deploy_cfg: Config, device: str):
|
||||
super(Inpainting, self).__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
model_files (Sequence[str]): Input model files.
|
||||
|
||||
Returns:
|
||||
nn.Module: An initialized backend model.
|
||||
"""
|
||||
from .inpainting_model import build_inpainting_model
|
||||
return build_inpainting_model(
|
||||
model_files,
|
||||
self.model_cfg,
|
||||
self.deploy_cfg,
|
||||
device=self.device,
|
||||
**kwargs)
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[dict] = None,
|
||||
**kwargs) -> nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
model_checkpoint (str): The checkpoint file of torch model,
|
||||
defaults to `None`.
|
||||
cfg_options (dict): Optional config key-pair parameters.
|
||||
|
||||
Returns:
|
||||
nn.Module: An initialized torch model generated by other OpenMMLab
|
||||
codebases.
|
||||
"""
|
||||
from mmedit.apis import init_model
|
||||
model = init_model(self.model_cfg, model_checkpoint, self.device)
|
||||
|
||||
forward_test = model.forward_test
|
||||
model.forward_test = lambda *args, **kwargs: {
|
||||
k: v
|
||||
for k, v in forward_test(*args, **kwargs).items()
|
||||
if k in get_ir_config(self.deploy_cfg).output_names
|
||||
}
|
||||
|
||||
return model.eval()
|
||||
|
||||
def create_input(self,
|
||||
imgs: Union[str, np.ndarray],
|
||||
input_shape: Optional[Sequence[int]] = None,
|
||||
pipeline_updater: Optional[Callable] = None,
|
||||
**kwargs) -> Tuple[dict, Tuple[torch.Tensor]]:
|
||||
"""Create input for model.
|
||||
|
||||
Args:
|
||||
imgs (str | np.ndarray | Sequence): Input image(s),
|
||||
accepted data types are `str`, `np.ndarray`.
|
||||
input_shape (Sequence[int] | None): Input shape of image in
|
||||
(width, height) format, defaults to `None`.
|
||||
pipeline_updater (function | None): A function to get a new
|
||||
pipeline.
|
||||
|
||||
Returns:
|
||||
tuple: (data, tuple), meta information for the input image and
|
||||
input tensors.
|
||||
"""
|
||||
from mmedit.datasets.pipelines import Compose
|
||||
|
||||
if isinstance(imgs, (list, tuple)):
|
||||
if not isinstance(imgs[0], (np.ndarray, str)):
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
elif isinstance(imgs, (np.ndarray, str)):
|
||||
imgs = [imgs]
|
||||
else:
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
|
||||
cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
||||
|
||||
test_pipeline = Compose(cfg.test_pipeline)
|
||||
|
||||
data_arr = []
|
||||
for img in imgs:
|
||||
if isinstance(img, np.ndarray):
|
||||
data = dict(gt_img=img)
|
||||
else:
|
||||
data = dict(gt_img_path=img)
|
||||
|
||||
data = test_pipeline(data)
|
||||
data_arr.append(data)
|
||||
|
||||
data = collate(data_arr, samples_per_gpu=len(imgs))
|
||||
|
||||
if self.device != 'cpu':
|
||||
data = scatter(data, [self.device])[0]
|
||||
|
||||
return data, (data['masked_img'], data['mask'])
|
||||
|
||||
def visualize(self,
|
||||
model: torch.nn.Module,
|
||||
image: Union[str, np.ndarray],
|
||||
result: list,
|
||||
output_file: str,
|
||||
window_name: str = '',
|
||||
show_result: bool = False,
|
||||
**kwargs) -> None:
|
||||
"""Visualize predictions of a model.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Input model.
|
||||
image (str | np.ndarray): Input image to draw predictions on.
|
||||
result (list): A list of predictions.
|
||||
output_file (str): Output file to save drawn image.
|
||||
window_name (str): The name of visualization window. Defaults to
|
||||
an empty string.
|
||||
show_result (bool): Whether to show result in windows, defaults
|
||||
to `False`.
|
||||
"""
|
||||
if len(result.shape) == 4:
|
||||
result = result[0]
|
||||
|
||||
result = result.transpose(1, 2, 0)
|
||||
result = (result + 1) * 127.5
|
||||
result = np.clip(result, 0, 255)
|
||||
|
||||
if show_result:
|
||||
int_result = result.astype(np.uint8)
|
||||
mmcv.imshow(int_result, window_name, 0)
|
||||
|
||||
output_file = None if show_result else output_file
|
||||
if output_file is not None:
|
||||
mmcv.imwrite(result, output_file)
|
||||
|
||||
if not (show_result or output_file):
|
||||
warnings.warn(
|
||||
'show_result==False and output_file is not specified, only '
|
||||
'result image will be returned')
|
||||
|
||||
@staticmethod
|
||||
def run_inference(model: nn.Module, model_inputs: dict) -> list:
|
||||
"""Run inference once for a model of a OpenMMLab Codebase.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Input model.
|
||||
model_inputs (dict): A dict containing model inputs tensor and
|
||||
meta info.
|
||||
|
||||
Returns:
|
||||
list: The predictions of model inference.
|
||||
"""
|
||||
results = model(model_inputs['masked_img'], model_inputs['mask'])
|
||||
if isinstance(results, dict):
|
||||
results = [results['fake_img']]
|
||||
|
||||
if not isinstance(results[0], np.ndarray):
|
||||
results = [results[0].detach().cpu().numpy()]
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_partition_cfg(partition_type: str, **kwargs) -> dict:
|
||||
"""Get a certain partition config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_tensor_from_input(input_data: dict, **kwargs) -> torch.Tensor:
|
||||
"""Get input tensor from input data."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def evaluate_outputs(model_cfg,
|
||||
outputs: list,
|
||||
dataset: Dataset,
|
||||
metrics: Optional[str] = None,
|
||||
out: Optional[str] = None,
|
||||
metric_options: Optional[dict] = None,
|
||||
format_only: bool = False,
|
||||
log_file: Optional[str] = None,
|
||||
json_file: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
"""Evaluation function implemented in mmedit.
|
||||
|
||||
Args:
|
||||
model_cfg (Config): The model config.
|
||||
outputs (list): A list of result of model inference.
|
||||
dataset (Dataset): Input dataset to run test.
|
||||
metrics (str): Evaluation metrics, which depends on
|
||||
the codebase and the dataset, e.g., "PSNR", "SSIM" in mmedit.
|
||||
out (str): Output result file in pickle format, defaults to `None`.
|
||||
metric_options (dict): Custom options for evaluation, will be
|
||||
kwargs for dataset.evaluate() function. Defaults to `None`.
|
||||
format_only (bool): Format the output results without perform
|
||||
evaluation. It is useful when you want to format the result
|
||||
to a specific format and submit it to the test server. Defaults
|
||||
to `False`.
|
||||
log_file (str | None): The file to write the evaluation results.
|
||||
Defaults to `None` and the results will only print on stdout.
|
||||
json_file (str | None): The file to write the evaluation metrics.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
from mmcv.utils import get_logger
|
||||
logger = get_logger('test', log_file=log_file)
|
||||
|
||||
if out:
|
||||
logger.debug(f'writing results to {out}')
|
||||
mmcv.dump(outputs, out)
|
||||
|
||||
stats = dataset.evaluate(outputs)
|
||||
if json_file is not None:
|
||||
mmcv.dump(stats, json_file, indent=4)
|
||||
|
||||
print('')
|
||||
for stat in stats:
|
||||
logger.info('Eval-{}: {}'.format(stat, stats[stat]))
|
||||
|
||||
def get_preprocess(self) -> dict:
|
||||
"""Get the preprocess information for SDK."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_postprocess(self) -> dict:
|
||||
"""Get the postprocess information for SDK."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""Get the model name."""
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,216 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.utils import Config, Registry
|
||||
from mmedit.core import L1Evaluation, psnr, ssim, tensor2img
|
||||
|
||||
from mmdeploy.codebase.base import BaseBackendModel
|
||||
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||
get_ir_config, load_config)
|
||||
|
||||
|
||||
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
|
||||
return registry.module_dict[cls_name](*args, **kwargs)
|
||||
|
||||
|
||||
__BACKEND_MODEL = mmcv.utils.Registry(
|
||||
'backend_models', build_func=__build_backend_model)
|
||||
|
||||
|
||||
@__BACKEND_MODEL.register_module('end2end')
|
||||
class End2EndModel(BaseBackendModel):
|
||||
"""End to end model for inference of inpainting.
|
||||
|
||||
Args:
|
||||
backend (Backend): The backend enum, specifying backend type.
|
||||
backend_files (Sequence[str]): Paths to all required backend files
|
||||
(e.g. '.onnx' for ONNX Runtime).
|
||||
device (str): A string represents device type.
|
||||
model_cfg(Config): Input model config object.
|
||||
deploy_cfg(str | Config): Deployment config file or loaded Config
|
||||
object.
|
||||
"""
|
||||
_eval_metrics = dict(l1=L1Evaluation, psnr=psnr, ssim=ssim)
|
||||
|
||||
def __init__(self,
|
||||
backend: Backend,
|
||||
backend_files: Sequence[str],
|
||||
device: str,
|
||||
model_cfg: Config,
|
||||
deploy_cfg: Union[str, Config] = None,
|
||||
**kwargs):
|
||||
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
|
||||
|
||||
self.input_names = get_ir_config(deploy_cfg).input_names
|
||||
|
||||
self.test_cfg = model_cfg.test_cfg
|
||||
|
||||
# init wrapper
|
||||
self.wrapper = self._build_wrapper(
|
||||
backend=backend,
|
||||
backend_files=backend_files,
|
||||
device=device,
|
||||
input_names=self.input_names,
|
||||
output_names=self.output_names,
|
||||
deploy_cfg=deploy_cfg,
|
||||
**kwargs)
|
||||
|
||||
def forward(self,
|
||||
masked_img: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
test_mode: bool = False,
|
||||
*args,
|
||||
**kwargs) -> Union[list, dict]:
|
||||
"""Run test inference for inpainting.
|
||||
|
||||
We want forward() to output an image or a evaluation result.
|
||||
When test_mode is set, the output is evaluation result. Otherwise
|
||||
it is an image.
|
||||
|
||||
Args:
|
||||
masked_img (torch.Tensor): Image with hole as input.
|
||||
mask (torch.Tensor): Mask as input.
|
||||
test_mode (bool, optional): Whether use testing mode.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
list | dict: Inpainted image or a evaluation results.
|
||||
"""
|
||||
if test_mode:
|
||||
return self.forward_test(masked_img, mask, *args, **kwargs)
|
||||
|
||||
return self.forward_dummy(masked_img, mask, *args, **kwargs)
|
||||
|
||||
def forward_test(self,
|
||||
masked_img: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
save_path=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""Run inference for inpaintor to generate evaluation result.
|
||||
|
||||
Args:
|
||||
masked_img (torch.Tensor): Image with hole as input.
|
||||
mask (torch.Tensor): Mask as input.
|
||||
save_path (str, optional): If given a valid str, the results will
|
||||
be saved in this path. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: Evaluation results.
|
||||
"""
|
||||
outputs = self.forward_dummy(masked_img, mask, *args, **kwargs)
|
||||
results = self.test_post_process(outputs, masked_img, mask, *args,
|
||||
**kwargs)
|
||||
|
||||
if save_path is not None:
|
||||
outputs = [torch.from_numpy(i).flip(1) for i in outputs]
|
||||
|
||||
filename, _ = osp.splitext(
|
||||
osp.basename(kwargs['meta'][0]['gt_img_path']))
|
||||
save_path = osp.join(save_path, f'{filename}.png')
|
||||
mmcv.imwrite(tensor2img(outputs, min_max=(-1, 1)), save_path)
|
||||
|
||||
return results
|
||||
|
||||
def forward_dummy(self, masked_img: torch.Tensor, mask: torch.Tensor,
|
||||
*args, **kwargs):
|
||||
"""Run test inference for inpaintor with backend wrapper.
|
||||
|
||||
Args:
|
||||
masked_img (torch.Tensor): Image with hole as input.
|
||||
mask (torch.Tensor): Mask as input.
|
||||
|
||||
Returns:
|
||||
list[np.ndarray] : Inpainted image.
|
||||
"""
|
||||
inputs = dict(masked_img=masked_img, mask=mask)
|
||||
outputs = self.wrapper(inputs)
|
||||
outputs = self.wrapper.output_to_list(outputs)
|
||||
outputs = [out.detach().cpu().numpy() for out in outputs]
|
||||
return outputs
|
||||
|
||||
def evaluate(self, output: Union[torch.Tensor, np.ndarray], masked_img,
|
||||
mask, **kwargs):
|
||||
"""Evaluation function implemented in mmedit.
|
||||
|
||||
Args:
|
||||
output (torch.Tensor | np.ndarray): Model output with
|
||||
shape (n, c, h, w).
|
||||
masked_img (torch.Tensor): Image with hole as input.
|
||||
mask (torch.Tensor): Mask as input.
|
||||
|
||||
Returns:
|
||||
dict: Evaluation results.
|
||||
"""
|
||||
|
||||
if isinstance(output, np.ndarray):
|
||||
output = torch.from_numpy(output)
|
||||
gt_img = kwargs['gt_img'].cpu()
|
||||
|
||||
eval_result = dict()
|
||||
data_dict = dict(gt_img=gt_img, fake_img=output, mask=mask.cpu())
|
||||
for metric in self.test_cfg.metrics:
|
||||
if metric in ['ssim', 'psnr']:
|
||||
eval_result[metric] = self._eval_metrics[metric](
|
||||
tensor2img(output, min_max=(-1, 1)),
|
||||
tensor2img(gt_img, min_max=(-1, 1)),
|
||||
)
|
||||
else:
|
||||
eval_result[metric] = self._eval_metrics[metric]()(
|
||||
data_dict).item()
|
||||
|
||||
return eval_result
|
||||
|
||||
def test_post_process(self, outputs: list, masked_img, mask, *args,
|
||||
**kwargs):
|
||||
"""Get evaluation results by post-processing model outputs.
|
||||
|
||||
Args:
|
||||
output (list[np.ndarray]) : The output inpainted image.
|
||||
masked_img (torch.Tensor): Image with hole as input.
|
||||
mask (torch.Tensor): Mask as input.
|
||||
|
||||
Returns:
|
||||
dict: Evaluation results.
|
||||
"""
|
||||
if self.test_cfg is not None and self.test_cfg.get('metrics', None):
|
||||
assert 'gt_img' in kwargs, (
|
||||
'evaluation with metrics must have gt images.')
|
||||
results = dict(
|
||||
eval_result=self.evaluate(outputs[0], masked_img, mask, *args,
|
||||
**kwargs))
|
||||
else:
|
||||
results = dict(masked_img=masked_img, fake_img=outputs)
|
||||
if 'gt_img' in kwargs:
|
||||
results['gt_img'] = kwargs['gt_img'].cpu()
|
||||
|
||||
return results
|
||||
|
||||
def show_result(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def build_inpainting_model(model_files: Sequence[str],
|
||||
model_cfg: Union[str, Config],
|
||||
deploy_cfg: Union[str,
|
||||
Config], device: str, **kwargs):
|
||||
model_cfg = load_config(model_cfg)[0]
|
||||
deploy_cfg = load_config(deploy_cfg)[0]
|
||||
|
||||
backend = get_backend(deploy_cfg)
|
||||
model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')
|
||||
|
||||
backend_model = __BACKEND_MODEL.build(
|
||||
model_type,
|
||||
backend=backend,
|
||||
backend_files=model_files,
|
||||
device=device,
|
||||
model_cfg=model_cfg,
|
||||
deploy_cfg=deploy_cfg,
|
||||
**kwargs)
|
||||
|
||||
return backend_model
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import contextual_attention # noqa: F401,F403
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
def _shape(x):
|
||||
return torch.gather(torch.tensor(x.shape), 0, torch.tensor(
|
||||
(0, 1, 2, 3))).tolist()
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmedit.models.common.contextual_attention.ContextualAttentionModule.'
|
||||
'patch_correlation')
|
||||
def contextual_attention__patch_correlation(ctx, self, x, kernel):
|
||||
# Force tensor shape to avoid the following RuntimeError:
|
||||
# Unsupported: ONNX export of convolution for kernel of unknown shape.
|
||||
kernel = kernel.reshape(_shape(kernel))
|
||||
return ctx.origin_func(self, x, kernel)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmedit.models.common.contextual_attention.ContextualAttentionModule.'
|
||||
'patch_copy_deconv')
|
||||
def contextual_attention__patch_copy_deconv(ctx, self, attention_score,
|
||||
context_filter):
|
||||
# Force tensor shape to avoid the following RuntimeError:
|
||||
# Unsupported: ONNX export of convolution for kernel of unknown shape.
|
||||
context_filter = context_filter.reshape(_shape(context_filter))
|
||||
return ctx.origin_func(self, attention_score, context_filter)
|
|
@ -29,6 +29,7 @@ class Task(AdvancedEnum):
|
|||
POSE_DETECTION = 'PoseDetection'
|
||||
ROTATED_DETECTION = 'RotatedDetection'
|
||||
VIDEO_RECOGNITION = 'VideoRecognition'
|
||||
INPAINTING = 'Inpainting'
|
||||
|
||||
|
||||
class Codebase(AdvancedEnum):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
imgs/blank.jpg
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
exp_name = 'deepfillv1_256x256_8x2_places'
|
||||
|
||||
model = dict(
|
||||
type='DeepFillv1Inpaintor',
|
||||
encdec=dict(
|
||||
type='DeepFillEncoderDecoder',
|
||||
stage1=dict(
|
||||
type='GLEncoderDecoder',
|
||||
encoder=dict(type='DeepFillEncoder', padding_mode='reflect'),
|
||||
decoder=dict(
|
||||
type='DeepFillDecoder',
|
||||
in_channels=128,
|
||||
padding_mode='reflect'),
|
||||
dilation_neck=dict(
|
||||
type='GLDilationNeck',
|
||||
in_channels=128,
|
||||
act_cfg=dict(type='ELU'),
|
||||
padding_mode='reflect')),
|
||||
stage2=dict(
|
||||
type='DeepFillRefiner',
|
||||
encoder_attention=dict(
|
||||
type='DeepFillEncoder',
|
||||
encoder_type='stage2_attention',
|
||||
padding_mode='reflect'),
|
||||
encoder_conv=dict(
|
||||
type='DeepFillEncoder',
|
||||
encoder_type='stage2_conv',
|
||||
padding_mode='reflect'),
|
||||
dilation_neck=dict(
|
||||
type='GLDilationNeck',
|
||||
in_channels=128,
|
||||
act_cfg=dict(type='ELU'),
|
||||
padding_mode='reflect'),
|
||||
contextual_attention=dict(
|
||||
type='ContextualAttentionNeck',
|
||||
in_channels=128,
|
||||
padding_mode='reflect'),
|
||||
decoder=dict(
|
||||
type='DeepFillDecoder',
|
||||
in_channels=256,
|
||||
padding_mode='reflect'))),
|
||||
disc=dict(
|
||||
type='DeepFillv1Discriminators',
|
||||
global_disc_cfg=dict(
|
||||
type='MultiLayerDiscriminator',
|
||||
in_channels=3,
|
||||
max_channels=256,
|
||||
fc_in_channels=65536,
|
||||
fc_out_channels=1,
|
||||
num_convs=4,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ELU'),
|
||||
out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
|
||||
local_disc_cfg=dict(
|
||||
type='MultiLayerDiscriminator',
|
||||
in_channels=3,
|
||||
max_channels=512,
|
||||
fc_in_channels=32768,
|
||||
fc_out_channels=1,
|
||||
num_convs=4,
|
||||
norm_cfg=None,
|
||||
act_cfg=dict(type='ELU'),
|
||||
out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2))),
|
||||
stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
|
||||
stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
|
||||
loss_gan=dict(type='GANLoss', gan_type='wgan', loss_weight=0.0001),
|
||||
loss_l1_hole=dict(type='L1Loss', loss_weight=1.0),
|
||||
loss_l1_valid=dict(type='L1Loss', loss_weight=1.0),
|
||||
loss_gp=dict(type='GradientPenaltyLoss', loss_weight=10.0),
|
||||
loss_disc_shift=dict(type='DiscShiftLoss', loss_weight=0.001),
|
||||
pretrained=None)
|
||||
|
||||
test_cfg = dict(metrics=['l1', 'psnr', 'ssim'])
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', key='gt_img'),
|
||||
dict(
|
||||
type='LoadMask',
|
||||
mask_mode='bbox',
|
||||
mask_config=dict(
|
||||
max_bbox_shape=(128, 128),
|
||||
max_bbox_delta=40,
|
||||
min_margin=20,
|
||||
img_shape=(256, 256))),
|
||||
dict(type='Crop', keys=['gt_img'], crop_size=(384, 384), random_crop=True),
|
||||
dict(type='Resize', keys=['gt_img'], scale=(256, 256), keep_ratio=False),
|
||||
dict(
|
||||
type='Normalize',
|
||||
keys=['gt_img'],
|
||||
mean=[127.5, 127.5, 127.5],
|
||||
std=[127.5, 127.5, 127.5],
|
||||
to_rgb=False),
|
||||
dict(type='GetMaskedImage'),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['gt_img', 'masked_img', 'mask', 'mask_bbox'],
|
||||
meta_keys=['gt_img_path']),
|
||||
dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask']),
|
||||
dict(type='ToTensor', keys=['mask_bbox'])
|
||||
]
|
||||
data = dict(
|
||||
test_dataloader=dict(samples_per_gpu=1),
|
||||
test=dict(
|
||||
type='ImgInpaintingDataset',
|
||||
ann_file='tests/test_codebase/test_mmedit/data/ann_file.txt',
|
||||
data_prefix='tests/test_codebase/test_mmedit/data',
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True))
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.apis import build_task_processor
|
||||
from mmdeploy.utils import load_config
|
||||
from mmdeploy.utils.test import SwitchBackendWrapper
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def model_cfg():
|
||||
cfg = 'tests/test_codebase/test_mmedit/data/inpainting_model.py'
|
||||
return load_config(cfg)[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def deploy_cfg():
|
||||
return mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='onnxruntime'),
|
||||
codebase_config=dict(type='mmedit', task='Inpainting'),
|
||||
onnx_config=dict(
|
||||
type='onnx',
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=False,
|
||||
opset_version=11,
|
||||
input_shape=None,
|
||||
input_names=['masked_img', 'mask'],
|
||||
output_names=['fake_img'])))
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def input_img():
|
||||
return np.random.rand(32, 32, 3)
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def model_input():
|
||||
return dict(
|
||||
masked_img=np.random.rand(32, 32, 3),
|
||||
mask=np.random.randint(0, 2, (32, 32)))
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def task_processor(model_cfg, deploy_cfg):
|
||||
return build_task_processor(model_cfg, deploy_cfg, device='cpu')
|
||||
|
||||
|
||||
def test_init_pytorch_model(task_processor):
|
||||
torch_model = task_processor.init_pytorch_model(model_checkpoint=None)
|
||||
assert torch_model is not None
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def backend_model(task_processor):
|
||||
from mmdeploy.backend.onnxruntime import ORTWrapper
|
||||
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||
wrapper.set(outputs=dict(fake_img=torch.rand(3, 32, 32)))
|
||||
yield task_processor.init_backend_model([''])
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
assert backend_model is not None
|
||||
|
||||
|
||||
def test_create_input(task_processor, input_img):
|
||||
inputs, _ = task_processor.create_input(
|
||||
input_img, img_shape=input_img.shape[:2])
|
||||
assert 'masked_img' in inputs
|
||||
assert 'mask' in inputs
|
||||
|
||||
|
||||
def test_visualize(backend_model, task_processor, model_input, input_img):
|
||||
result = task_processor.run_inference(backend_model, model_input)
|
||||
with tempfile.TemporaryDirectory() as dir:
|
||||
filename = dir + 'tmp.jpg'
|
||||
task_processor.visualize(backend_model, input_img, result[0], filename,
|
||||
'onnxruntime')
|
||||
assert os.path.exists(filename)
|
||||
|
||||
|
||||
def test_run_inference(backend_model, task_processor, model_input):
|
||||
results = task_processor.run_inference(backend_model, model_input)
|
||||
assert results is not None
|
||||
|
||||
|
||||
def test_get_tensor_from_input(task_processor, model_input):
|
||||
with pytest.raises(NotImplementedError):
|
||||
task_processor.get_tensor_from_input(model_input)
|
||||
|
||||
|
||||
def test_get_partition_cfg(task_processor):
|
||||
with pytest.raises(NotImplementedError):
|
||||
task_processor.get_partition_cfg(None)
|
||||
|
||||
|
||||
def test_build_dataset(task_processor):
|
||||
data = dict(
|
||||
test=dict(
|
||||
type='ImgInpaintingDataset',
|
||||
ann_file='tests/test_codebase/test_mmedit/data/ann_file.txt',
|
||||
data_prefix='tests/test_codebase/test_mmedit/data',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile', key='gt_img'),
|
||||
dict(type='LoadMask')
|
||||
]))
|
||||
dataset_cfg = mmcv.Config(dict(data=data))
|
||||
dataset = task_processor.build_dataset(
|
||||
dataset_cfg=dataset_cfg, dataset_type='test')
|
||||
assert dataset is not None, 'Failed to build dataset'
|
||||
dataloader = task_processor.build_dataloader(dataset, 1, 1)
|
||||
assert dataloader is not None, 'Failed to build dataloader'
|
||||
|
||||
|
||||
def test_single_gpu_test(backend_model, model_cfg, task_processor):
|
||||
from mmcv.parallel import MMDataParallel
|
||||
dataset = task_processor.build_dataset(model_cfg, dataset_type='test')
|
||||
assert dataset is not None, 'Failed to build dataset'
|
||||
dataloader = task_processor.build_dataloader(dataset, 1, 1)
|
||||
assert dataloader is not None, 'Failed to build dataloader'
|
||||
backend_model = MMDataParallel(backend_model, device_ids=[0])
|
||||
outputs = task_processor.single_gpu_test(backend_model, dataloader)
|
||||
assert outputs is not None, 'Failed to test model'
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.utils import Backend, load_config
|
||||
from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
|
||||
|
||||
|
||||
@backend_checker(Backend.ONNXRUNTIME)
|
||||
class TestEnd2EndModel:
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
def end2end_model(self):
|
||||
# force add backend wrapper regardless of plugins
|
||||
# make sure ONNXRuntimeEditor can use ORTWrapper inside itself
|
||||
from mmdeploy.backend.onnxruntime import ORTWrapper
|
||||
from mmdeploy.codebase.mmedit.deploy.inpainting_model import \
|
||||
End2EndModel
|
||||
|
||||
# simplify backend inference
|
||||
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||
wrapper.set(outputs=dict(fake_img=torch.rand(3, 32, 32)))
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
onnx_config=dict(
|
||||
input_names=['masked_img', 'mask'],
|
||||
output_names=['fake_img'])))
|
||||
model_cfg = load_config(
|
||||
'tests/test_codebase/test_mmedit/data/inpainting_model.py')[0]
|
||||
model = End2EndModel(Backend.ONNXRUNTIME, [''], 'cpu', model_cfg,
|
||||
deploy_cfg)
|
||||
yield model
|
||||
|
||||
def test_forward(self, end2end_model):
|
||||
masked_img = np.random.rand(3, 32, 32)
|
||||
mask = np.random.randint(0, 2, (1, 32, 32))
|
||||
|
||||
results = end2end_model.forward(masked_img, mask, test_mode=False)
|
||||
assert results is not None
|
||||
|
||||
results = end2end_model.forward(
|
||||
masked_img,
|
||||
torch.tensor(mask),
|
||||
test_mode=True,
|
||||
gt_img=torch.tensor(results[0]))
|
||||
assert results is not None
|
Loading…
Reference in New Issue