mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
Add support for converting a inpainting model to ONNX and TensorRT (#1831)
* Add support for inpainting models * Add configs * Add comment * Refactor * Add test code for inpainting task * Fix * Fix * Update * Fix * Fix * Update docs * Update * Fix visualization * Handle case without Resize
This commit is contained in:
parent
aae9f32623
commit
f7c484a046
20
configs/mmedit/inpainting/inpainting_dynamic.py
Normal file
20
configs/mmedit/inpainting/inpainting_dynamic.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
_base_ = ['./inpainting_static.py']
|
||||||
|
|
||||||
|
onnx_config = dict(
|
||||||
|
dynamic_axes=dict(
|
||||||
|
masked_img={
|
||||||
|
0: 'batch',
|
||||||
|
2: 'height',
|
||||||
|
3: 'width'
|
||||||
|
},
|
||||||
|
mask={
|
||||||
|
0: 'batch',
|
||||||
|
2: 'height',
|
||||||
|
3: 'width'
|
||||||
|
},
|
||||||
|
output={
|
||||||
|
0: 'batch',
|
||||||
|
2: 'height',
|
||||||
|
3: 'width'
|
||||||
|
}),
|
||||||
|
input_shape=None)
|
@ -0,0 +1 @@
|
|||||||
|
_base_ = ['./inpainting_dynamic.py', '../../_base_/backends/onnxruntime.py']
|
@ -0,0 +1,3 @@
|
|||||||
|
_base_ = ['./inpainting_static.py', '../../_base_/backends/onnxruntime.py']
|
||||||
|
|
||||||
|
onnx_config = dict(input_shape=[256, 256])
|
5
configs/mmedit/inpainting/inpainting_static.py
Normal file
5
configs/mmedit/inpainting/inpainting_static.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
_base_ = ['../../_base_/onnx_config.py']
|
||||||
|
|
||||||
|
codebase_config = dict(type='mmedit', task='Inpainting')
|
||||||
|
onnx_config = dict(
|
||||||
|
input_names=['masked_img', 'mask'], output_names=['fake_img'])
|
@ -0,0 +1,17 @@
|
|||||||
|
_base_ = ['./inpainting_static.py', '../../_base_/backends/tensorrt-fp16.py']
|
||||||
|
|
||||||
|
onnx_config = dict(input_shape=[256, 256])
|
||||||
|
backend_config = dict(
|
||||||
|
common_config=dict(max_workspace_size=1 << 30),
|
||||||
|
model_inputs=[
|
||||||
|
dict(
|
||||||
|
input_shapes=dict(
|
||||||
|
masked_img=dict(
|
||||||
|
min_shape=[1, 3, 256, 256],
|
||||||
|
opt_shape=[1, 3, 256, 256],
|
||||||
|
max_shape=[1, 3, 256, 256]),
|
||||||
|
mask=dict(
|
||||||
|
min_shape=[1, 1, 256, 256],
|
||||||
|
opt_shape=[1, 1, 256, 256],
|
||||||
|
max_shape=[1, 1, 256, 256])))
|
||||||
|
])
|
@ -0,0 +1,17 @@
|
|||||||
|
_base_ = ['./inpainting_static.py', '../../_base_/backends/tensorrt-int8.py']
|
||||||
|
|
||||||
|
onnx_config = dict(input_shape=[256, 256])
|
||||||
|
backend_config = dict(
|
||||||
|
common_config=dict(max_workspace_size=1 << 30),
|
||||||
|
model_inputs=[
|
||||||
|
dict(
|
||||||
|
input_shapes=dict(
|
||||||
|
masked_img=dict(
|
||||||
|
min_shape=[1, 3, 256, 256],
|
||||||
|
opt_shape=[1, 3, 256, 256],
|
||||||
|
max_shape=[1, 3, 256, 256]),
|
||||||
|
mask=dict(
|
||||||
|
min_shape=[1, 1, 256, 256],
|
||||||
|
opt_shape=[1, 1, 256, 256],
|
||||||
|
max_shape=[1, 1, 256, 256])))
|
||||||
|
])
|
@ -0,0 +1,17 @@
|
|||||||
|
_base_ = ['./inpainting_static.py', '../../_base_/backends/tensorrt.py']
|
||||||
|
|
||||||
|
onnx_config = dict(input_shape=[256, 256])
|
||||||
|
backend_config = dict(
|
||||||
|
common_config=dict(max_workspace_size=1 << 30),
|
||||||
|
model_inputs=[
|
||||||
|
dict(
|
||||||
|
input_shapes=dict(
|
||||||
|
masked_img=dict(
|
||||||
|
min_shape=[1, 3, 256, 256],
|
||||||
|
opt_shape=[1, 3, 256, 256],
|
||||||
|
max_shape=[1, 3, 256, 256]),
|
||||||
|
mask=dict(
|
||||||
|
min_shape=[1, 1, 256, 256],
|
||||||
|
opt_shape=[1, 1, 256, 256],
|
||||||
|
max_shape=[1, 1, 256, 256])))
|
||||||
|
])
|
@ -8,13 +8,20 @@ Please refer to [official installation guide](https://mmediting.readthedocs.io/e
|
|||||||
|
|
||||||
## MMEditing models support
|
## MMEditing models support
|
||||||
|
|
||||||
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
| Model | Task | ONNX Runtime | TensorRT | ncnn | PPLNN | OpenVINO | Model config |
|
||||||
| :---------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: |
|
| :------------- | :--------------- | :----------: | :------: | :--: | :---: | :------: | :--------------------------------------------------------------------------------------------: |
|
||||||
| SRCNN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
| SRCNN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srcnn) |
|
||||||
| ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
| ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||||
| ESRGAN-PSNR | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
| ESRGAN-PSNR | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/esrgan) |
|
||||||
| SRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
| SRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||||
| SRResNet | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
| SRResNet | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/srresnet_srgan) |
|
||||||
| Real-ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
| Real-ESRGAN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/real_esrgan) |
|
||||||
| EDSR | super-resolution | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
| EDSR | super-resolution | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/edsr) |
|
||||||
| RDN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
| RDN | super-resolution | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/restorers/rdn) |
|
||||||
|
| Global&Local\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/global_local) |
|
||||||
|
| DeepFillv1\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/deepfillv1) |
|
||||||
|
| PConv\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/partial_conv) |
|
||||||
|
| DeepFillv2\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/deepfillv2) |
|
||||||
|
| AOT-GAN\* | inpainting | Y | Y | N | N | N | [config](https://github.com/open-mmlab/mmediting/tree/master/configs/inpainting/AOT-GAN) |
|
||||||
|
|
||||||
|
1. We skipped quantitative evaluation for image inpainting due to the high computational cost required for testing.
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .deploy import MMEditing, SuperResolution
|
from .deploy import Inpainting, MMEditing, SuperResolution
|
||||||
|
from .models import * # noqa: F401,F403
|
||||||
|
|
||||||
__all__ = ['MMEditing', 'SuperResolution']
|
__all__ = ['MMEditing', 'SuperResolution', 'Inpainting']
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mmdeploy.codebase.mmedit.deploy.inpainting import Inpainting
|
||||||
from mmdeploy.codebase.mmedit.deploy.mmediting import MMEditing
|
from mmdeploy.codebase.mmedit.deploy.mmediting import MMEditing
|
||||||
from mmdeploy.codebase.mmedit.deploy.super_resolution import SuperResolution
|
from mmdeploy.codebase.mmedit.deploy.super_resolution import SuperResolution
|
||||||
|
|
||||||
__all__ = ['MMEditing', 'SuperResolution']
|
__all__ = ['MMEditing', 'SuperResolution', 'Inpainting']
|
||||||
|
293
mmdeploy/codebase/mmedit/deploy/inpainting.py
Normal file
293
mmdeploy/codebase/mmedit/deploy/inpainting.py
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import warnings
|
||||||
|
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.parallel import collate, scatter
|
||||||
|
from mmcv.utils import Config
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from mmdeploy.codebase.base import BaseTask
|
||||||
|
from mmdeploy.codebase.mmedit.deploy.mmediting import MMEDIT_TASK
|
||||||
|
from mmdeploy.utils import Task, get_ir_config, load_config
|
||||||
|
|
||||||
|
|
||||||
|
def process_model_config(model_cfg: Config,
|
||||||
|
imgs: Union[Sequence[str], Sequence[np.ndarray]],
|
||||||
|
input_shape: Optional[Sequence[int]] = None):
|
||||||
|
"""Process the model config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_cfg (Config): The model config.
|
||||||
|
imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted
|
||||||
|
data type are List[str], List[np.ndarray].
|
||||||
|
input_shape (Sequence[int], optional): A list of two integer
|
||||||
|
in (width, height) format specifying input shape. Default: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Config: the model config after processing.
|
||||||
|
"""
|
||||||
|
config = load_config(model_cfg)[0].copy()
|
||||||
|
load_from_file = isinstance(imgs[0], str)
|
||||||
|
if not load_from_file:
|
||||||
|
# Remove 'LoadImageFromFile'
|
||||||
|
config.test_pipeline.pop(0)
|
||||||
|
|
||||||
|
if input_shape is not None:
|
||||||
|
# Fix the input shape by 'Resize' or 'RandomResizedCrop'
|
||||||
|
for pipeline in config.test_pipeline[::-1]:
|
||||||
|
if pipeline.type == 'Resize':
|
||||||
|
pipeline.scale = tuple(input_shape)
|
||||||
|
break
|
||||||
|
elif pipeline.type == 'RandomResizedCrop':
|
||||||
|
pipeline.crop_size = tuple(input_shape)
|
||||||
|
break
|
||||||
|
|
||||||
|
key = 'gt_img_path'
|
||||||
|
for pipeline in config.test_pipeline:
|
||||||
|
if 'meta_keys' in pipeline:
|
||||||
|
while key in pipeline['meta_keys']:
|
||||||
|
pipeline['meta_keys'].remove(key)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@MMEDIT_TASK.register_module(Task.INPAINTING.value)
|
||||||
|
class Inpainting(BaseTask):
|
||||||
|
"""BaseTask class of inpainting task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_cfg (Config): Model config file.
|
||||||
|
deploy_cfg (Config): Deployment config file.
|
||||||
|
device (str): A string specifying device type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_cfg: Config, deploy_cfg: Config, device: str):
|
||||||
|
super(Inpainting, self).__init__(model_cfg, deploy_cfg, device)
|
||||||
|
|
||||||
|
def init_backend_model(self,
|
||||||
|
model_files: Sequence[str] = None,
|
||||||
|
**kwargs) -> nn.Module:
|
||||||
|
"""Initialize backend model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_files (Sequence[str]): Input model files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: An initialized backend model.
|
||||||
|
"""
|
||||||
|
from .inpainting_model import build_inpainting_model
|
||||||
|
return build_inpainting_model(
|
||||||
|
model_files,
|
||||||
|
self.model_cfg,
|
||||||
|
self.deploy_cfg,
|
||||||
|
device=self.device,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def init_pytorch_model(self,
|
||||||
|
model_checkpoint: Optional[str] = None,
|
||||||
|
cfg_options: Optional[dict] = None,
|
||||||
|
**kwargs) -> nn.Module:
|
||||||
|
"""Initialize torch model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_checkpoint (str): The checkpoint file of torch model,
|
||||||
|
defaults to `None`.
|
||||||
|
cfg_options (dict): Optional config key-pair parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: An initialized torch model generated by other OpenMMLab
|
||||||
|
codebases.
|
||||||
|
"""
|
||||||
|
from mmedit.apis import init_model
|
||||||
|
model = init_model(self.model_cfg, model_checkpoint, self.device)
|
||||||
|
|
||||||
|
forward_test = model.forward_test
|
||||||
|
model.forward_test = lambda *args, **kwargs: {
|
||||||
|
k: v
|
||||||
|
for k, v in forward_test(*args, **kwargs).items()
|
||||||
|
if k in get_ir_config(self.deploy_cfg).output_names
|
||||||
|
}
|
||||||
|
|
||||||
|
return model.eval()
|
||||||
|
|
||||||
|
def create_input(self,
|
||||||
|
imgs: Union[str, np.ndarray],
|
||||||
|
input_shape: Optional[Sequence[int]] = None,
|
||||||
|
pipeline_updater: Optional[Callable] = None,
|
||||||
|
**kwargs) -> Tuple[dict, Tuple[torch.Tensor]]:
|
||||||
|
"""Create input for model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imgs (str | np.ndarray | Sequence): Input image(s),
|
||||||
|
accepted data types are `str`, `np.ndarray`.
|
||||||
|
input_shape (Sequence[int] | None): Input shape of image in
|
||||||
|
(width, height) format, defaults to `None`.
|
||||||
|
pipeline_updater (function | None): A function to get a new
|
||||||
|
pipeline.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (data, tuple), meta information for the input image and
|
||||||
|
input tensors.
|
||||||
|
"""
|
||||||
|
from mmedit.datasets.pipelines import Compose
|
||||||
|
|
||||||
|
if isinstance(imgs, (list, tuple)):
|
||||||
|
if not isinstance(imgs[0], (np.ndarray, str)):
|
||||||
|
raise AssertionError('imgs must be strings or numpy arrays')
|
||||||
|
elif isinstance(imgs, (np.ndarray, str)):
|
||||||
|
imgs = [imgs]
|
||||||
|
else:
|
||||||
|
raise AssertionError('imgs must be strings or numpy arrays')
|
||||||
|
|
||||||
|
cfg = process_model_config(self.model_cfg, imgs, input_shape)
|
||||||
|
|
||||||
|
test_pipeline = Compose(cfg.test_pipeline)
|
||||||
|
|
||||||
|
data_arr = []
|
||||||
|
for img in imgs:
|
||||||
|
if isinstance(img, np.ndarray):
|
||||||
|
data = dict(gt_img=img)
|
||||||
|
else:
|
||||||
|
data = dict(gt_img_path=img)
|
||||||
|
|
||||||
|
data = test_pipeline(data)
|
||||||
|
data_arr.append(data)
|
||||||
|
|
||||||
|
data = collate(data_arr, samples_per_gpu=len(imgs))
|
||||||
|
|
||||||
|
if self.device != 'cpu':
|
||||||
|
data = scatter(data, [self.device])[0]
|
||||||
|
|
||||||
|
return data, (data['masked_img'], data['mask'])
|
||||||
|
|
||||||
|
def visualize(self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
image: Union[str, np.ndarray],
|
||||||
|
result: list,
|
||||||
|
output_file: str,
|
||||||
|
window_name: str = '',
|
||||||
|
show_result: bool = False,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""Visualize predictions of a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Input model.
|
||||||
|
image (str | np.ndarray): Input image to draw predictions on.
|
||||||
|
result (list): A list of predictions.
|
||||||
|
output_file (str): Output file to save drawn image.
|
||||||
|
window_name (str): The name of visualization window. Defaults to
|
||||||
|
an empty string.
|
||||||
|
show_result (bool): Whether to show result in windows, defaults
|
||||||
|
to `False`.
|
||||||
|
"""
|
||||||
|
if len(result.shape) == 4:
|
||||||
|
result = result[0]
|
||||||
|
|
||||||
|
result = result.transpose(1, 2, 0)
|
||||||
|
result = (result + 1) * 127.5
|
||||||
|
result = np.clip(result, 0, 255)
|
||||||
|
|
||||||
|
if show_result:
|
||||||
|
int_result = result.astype(np.uint8)
|
||||||
|
mmcv.imshow(int_result, window_name, 0)
|
||||||
|
|
||||||
|
output_file = None if show_result else output_file
|
||||||
|
if output_file is not None:
|
||||||
|
mmcv.imwrite(result, output_file)
|
||||||
|
|
||||||
|
if not (show_result or output_file):
|
||||||
|
warnings.warn(
|
||||||
|
'show_result==False and output_file is not specified, only '
|
||||||
|
'result image will be returned')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run_inference(model: nn.Module, model_inputs: dict) -> list:
|
||||||
|
"""Run inference once for a model of a OpenMMLab Codebase.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): Input model.
|
||||||
|
model_inputs (dict): A dict containing model inputs tensor and
|
||||||
|
meta info.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: The predictions of model inference.
|
||||||
|
"""
|
||||||
|
results = model(model_inputs['masked_img'], model_inputs['mask'])
|
||||||
|
if isinstance(results, dict):
|
||||||
|
results = [results['fake_img']]
|
||||||
|
|
||||||
|
if not isinstance(results[0], np.ndarray):
|
||||||
|
results = [results[0].detach().cpu().numpy()]
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_partition_cfg(partition_type: str, **kwargs) -> dict:
|
||||||
|
"""Get a certain partition config."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tensor_from_input(input_data: dict, **kwargs) -> torch.Tensor:
|
||||||
|
"""Get input tensor from input data."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def evaluate_outputs(model_cfg,
|
||||||
|
outputs: list,
|
||||||
|
dataset: Dataset,
|
||||||
|
metrics: Optional[str] = None,
|
||||||
|
out: Optional[str] = None,
|
||||||
|
metric_options: Optional[dict] = None,
|
||||||
|
format_only: bool = False,
|
||||||
|
log_file: Optional[str] = None,
|
||||||
|
json_file: Optional[str] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
|
"""Evaluation function implemented in mmedit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_cfg (Config): The model config.
|
||||||
|
outputs (list): A list of result of model inference.
|
||||||
|
dataset (Dataset): Input dataset to run test.
|
||||||
|
metrics (str): Evaluation metrics, which depends on
|
||||||
|
the codebase and the dataset, e.g., "PSNR", "SSIM" in mmedit.
|
||||||
|
out (str): Output result file in pickle format, defaults to `None`.
|
||||||
|
metric_options (dict): Custom options for evaluation, will be
|
||||||
|
kwargs for dataset.evaluate() function. Defaults to `None`.
|
||||||
|
format_only (bool): Format the output results without perform
|
||||||
|
evaluation. It is useful when you want to format the result
|
||||||
|
to a specific format and submit it to the test server. Defaults
|
||||||
|
to `False`.
|
||||||
|
log_file (str | None): The file to write the evaluation results.
|
||||||
|
Defaults to `None` and the results will only print on stdout.
|
||||||
|
json_file (str | None): The file to write the evaluation metrics.
|
||||||
|
Defaults to `None`.
|
||||||
|
"""
|
||||||
|
from mmcv.utils import get_logger
|
||||||
|
logger = get_logger('test', log_file=log_file)
|
||||||
|
|
||||||
|
if out:
|
||||||
|
logger.debug(f'writing results to {out}')
|
||||||
|
mmcv.dump(outputs, out)
|
||||||
|
|
||||||
|
stats = dataset.evaluate(outputs)
|
||||||
|
if json_file is not None:
|
||||||
|
mmcv.dump(stats, json_file, indent=4)
|
||||||
|
|
||||||
|
print('')
|
||||||
|
for stat in stats:
|
||||||
|
logger.info('Eval-{}: {}'.format(stat, stats[stat]))
|
||||||
|
|
||||||
|
def get_preprocess(self) -> dict:
|
||||||
|
"""Get the preprocess information for SDK."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_postprocess(self) -> dict:
|
||||||
|
"""Get the postprocess information for SDK."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_model_name(self) -> str:
|
||||||
|
"""Get the model name."""
|
||||||
|
raise NotImplementedError
|
216
mmdeploy/codebase/mmedit/deploy/inpainting_model.py
Normal file
216
mmdeploy/codebase/mmedit/deploy/inpainting_model.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmcv.utils import Config, Registry
|
||||||
|
from mmedit.core import L1Evaluation, psnr, ssim, tensor2img
|
||||||
|
|
||||||
|
from mmdeploy.codebase.base import BaseBackendModel
|
||||||
|
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||||
|
get_ir_config, load_config)
|
||||||
|
|
||||||
|
|
||||||
|
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
|
||||||
|
return registry.module_dict[cls_name](*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__BACKEND_MODEL = mmcv.utils.Registry(
|
||||||
|
'backend_models', build_func=__build_backend_model)
|
||||||
|
|
||||||
|
|
||||||
|
@__BACKEND_MODEL.register_module('end2end')
|
||||||
|
class End2EndModel(BaseBackendModel):
|
||||||
|
"""End to end model for inference of inpainting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend (Backend): The backend enum, specifying backend type.
|
||||||
|
backend_files (Sequence[str]): Paths to all required backend files
|
||||||
|
(e.g. '.onnx' for ONNX Runtime).
|
||||||
|
device (str): A string represents device type.
|
||||||
|
model_cfg(Config): Input model config object.
|
||||||
|
deploy_cfg(str | Config): Deployment config file or loaded Config
|
||||||
|
object.
|
||||||
|
"""
|
||||||
|
_eval_metrics = dict(l1=L1Evaluation, psnr=psnr, ssim=ssim)
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
backend: Backend,
|
||||||
|
backend_files: Sequence[str],
|
||||||
|
device: str,
|
||||||
|
model_cfg: Config,
|
||||||
|
deploy_cfg: Union[str, Config] = None,
|
||||||
|
**kwargs):
|
||||||
|
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
|
||||||
|
|
||||||
|
self.input_names = get_ir_config(deploy_cfg).input_names
|
||||||
|
|
||||||
|
self.test_cfg = model_cfg.test_cfg
|
||||||
|
|
||||||
|
# init wrapper
|
||||||
|
self.wrapper = self._build_wrapper(
|
||||||
|
backend=backend,
|
||||||
|
backend_files=backend_files,
|
||||||
|
device=device,
|
||||||
|
input_names=self.input_names,
|
||||||
|
output_names=self.output_names,
|
||||||
|
deploy_cfg=deploy_cfg,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
masked_img: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
test_mode: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs) -> Union[list, dict]:
|
||||||
|
"""Run test inference for inpainting.
|
||||||
|
|
||||||
|
We want forward() to output an image or a evaluation result.
|
||||||
|
When test_mode is set, the output is evaluation result. Otherwise
|
||||||
|
it is an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masked_img (torch.Tensor): Image with hole as input.
|
||||||
|
mask (torch.Tensor): Mask as input.
|
||||||
|
test_mode (bool, optional): Whether use testing mode.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list | dict: Inpainted image or a evaluation results.
|
||||||
|
"""
|
||||||
|
if test_mode:
|
||||||
|
return self.forward_test(masked_img, mask, *args, **kwargs)
|
||||||
|
|
||||||
|
return self.forward_dummy(masked_img, mask, *args, **kwargs)
|
||||||
|
|
||||||
|
def forward_test(self,
|
||||||
|
masked_img: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
save_path=None,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
"""Run inference for inpaintor to generate evaluation result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masked_img (torch.Tensor): Image with hole as input.
|
||||||
|
mask (torch.Tensor): Mask as input.
|
||||||
|
save_path (str, optional): If given a valid str, the results will
|
||||||
|
be saved in this path. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Evaluation results.
|
||||||
|
"""
|
||||||
|
outputs = self.forward_dummy(masked_img, mask, *args, **kwargs)
|
||||||
|
results = self.test_post_process(outputs, masked_img, mask, *args,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
if save_path is not None:
|
||||||
|
outputs = [torch.from_numpy(i).flip(1) for i in outputs]
|
||||||
|
|
||||||
|
filename, _ = osp.splitext(
|
||||||
|
osp.basename(kwargs['meta'][0]['gt_img_path']))
|
||||||
|
save_path = osp.join(save_path, f'{filename}.png')
|
||||||
|
mmcv.imwrite(tensor2img(outputs, min_max=(-1, 1)), save_path)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def forward_dummy(self, masked_img: torch.Tensor, mask: torch.Tensor,
|
||||||
|
*args, **kwargs):
|
||||||
|
"""Run test inference for inpaintor with backend wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
masked_img (torch.Tensor): Image with hole as input.
|
||||||
|
mask (torch.Tensor): Mask as input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[np.ndarray] : Inpainted image.
|
||||||
|
"""
|
||||||
|
inputs = dict(masked_img=masked_img, mask=mask)
|
||||||
|
outputs = self.wrapper(inputs)
|
||||||
|
outputs = self.wrapper.output_to_list(outputs)
|
||||||
|
outputs = [out.detach().cpu().numpy() for out in outputs]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def evaluate(self, output: Union[torch.Tensor, np.ndarray], masked_img,
|
||||||
|
mask, **kwargs):
|
||||||
|
"""Evaluation function implemented in mmedit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output (torch.Tensor | np.ndarray): Model output with
|
||||||
|
shape (n, c, h, w).
|
||||||
|
masked_img (torch.Tensor): Image with hole as input.
|
||||||
|
mask (torch.Tensor): Mask as input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Evaluation results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(output, np.ndarray):
|
||||||
|
output = torch.from_numpy(output)
|
||||||
|
gt_img = kwargs['gt_img'].cpu()
|
||||||
|
|
||||||
|
eval_result = dict()
|
||||||
|
data_dict = dict(gt_img=gt_img, fake_img=output, mask=mask.cpu())
|
||||||
|
for metric in self.test_cfg.metrics:
|
||||||
|
if metric in ['ssim', 'psnr']:
|
||||||
|
eval_result[metric] = self._eval_metrics[metric](
|
||||||
|
tensor2img(output, min_max=(-1, 1)),
|
||||||
|
tensor2img(gt_img, min_max=(-1, 1)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
eval_result[metric] = self._eval_metrics[metric]()(
|
||||||
|
data_dict).item()
|
||||||
|
|
||||||
|
return eval_result
|
||||||
|
|
||||||
|
def test_post_process(self, outputs: list, masked_img, mask, *args,
|
||||||
|
**kwargs):
|
||||||
|
"""Get evaluation results by post-processing model outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output (list[np.ndarray]) : The output inpainted image.
|
||||||
|
masked_img (torch.Tensor): Image with hole as input.
|
||||||
|
mask (torch.Tensor): Mask as input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Evaluation results.
|
||||||
|
"""
|
||||||
|
if self.test_cfg is not None and self.test_cfg.get('metrics', None):
|
||||||
|
assert 'gt_img' in kwargs, (
|
||||||
|
'evaluation with metrics must have gt images.')
|
||||||
|
results = dict(
|
||||||
|
eval_result=self.evaluate(outputs[0], masked_img, mask, *args,
|
||||||
|
**kwargs))
|
||||||
|
else:
|
||||||
|
results = dict(masked_img=masked_img, fake_img=outputs)
|
||||||
|
if 'gt_img' in kwargs:
|
||||||
|
results['gt_img'] = kwargs['gt_img'].cpu()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def show_result(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def build_inpainting_model(model_files: Sequence[str],
|
||||||
|
model_cfg: Union[str, Config],
|
||||||
|
deploy_cfg: Union[str,
|
||||||
|
Config], device: str, **kwargs):
|
||||||
|
model_cfg = load_config(model_cfg)[0]
|
||||||
|
deploy_cfg = load_config(deploy_cfg)[0]
|
||||||
|
|
||||||
|
backend = get_backend(deploy_cfg)
|
||||||
|
model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')
|
||||||
|
|
||||||
|
backend_model = __BACKEND_MODEL.build(
|
||||||
|
model_type,
|
||||||
|
backend=backend,
|
||||||
|
backend_files=model_files,
|
||||||
|
device=device,
|
||||||
|
model_cfg=model_cfg,
|
||||||
|
deploy_cfg=deploy_cfg,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return backend_model
|
2
mmdeploy/codebase/mmedit/models/__init__.py
Normal file
2
mmdeploy/codebase/mmedit/models/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from . import contextual_attention # noqa: F401,F403
|
30
mmdeploy/codebase/mmedit/models/contextual_attention.py
Normal file
30
mmdeploy/codebase/mmedit/models/contextual_attention.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
|
||||||
|
|
||||||
|
def _shape(x):
|
||||||
|
return torch.gather(torch.tensor(x.shape), 0, torch.tensor(
|
||||||
|
(0, 1, 2, 3))).tolist()
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
'mmedit.models.common.contextual_attention.ContextualAttentionModule.'
|
||||||
|
'patch_correlation')
|
||||||
|
def contextual_attention__patch_correlation(ctx, self, x, kernel):
|
||||||
|
# Force tensor shape to avoid the following RuntimeError:
|
||||||
|
# Unsupported: ONNX export of convolution for kernel of unknown shape.
|
||||||
|
kernel = kernel.reshape(_shape(kernel))
|
||||||
|
return ctx.origin_func(self, x, kernel)
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
'mmedit.models.common.contextual_attention.ContextualAttentionModule.'
|
||||||
|
'patch_copy_deconv')
|
||||||
|
def contextual_attention__patch_copy_deconv(ctx, self, attention_score,
|
||||||
|
context_filter):
|
||||||
|
# Force tensor shape to avoid the following RuntimeError:
|
||||||
|
# Unsupported: ONNX export of convolution for kernel of unknown shape.
|
||||||
|
context_filter = context_filter.reshape(_shape(context_filter))
|
||||||
|
return ctx.origin_func(self, attention_score, context_filter)
|
@ -29,6 +29,7 @@ class Task(AdvancedEnum):
|
|||||||
POSE_DETECTION = 'PoseDetection'
|
POSE_DETECTION = 'PoseDetection'
|
||||||
ROTATED_DETECTION = 'RotatedDetection'
|
ROTATED_DETECTION = 'RotatedDetection'
|
||||||
VIDEO_RECOGNITION = 'VideoRecognition'
|
VIDEO_RECOGNITION = 'VideoRecognition'
|
||||||
|
INPAINTING = 'Inpainting'
|
||||||
|
|
||||||
|
|
||||||
class Codebase(AdvancedEnum):
|
class Codebase(AdvancedEnum):
|
||||||
|
1
tests/test_codebase/test_mmedit/data/ann_file.txt
Normal file
1
tests/test_codebase/test_mmedit/data/ann_file.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
imgs/blank.jpg
|
109
tests/test_codebase/test_mmedit/data/inpainting_model.py
Normal file
109
tests/test_codebase/test_mmedit/data/inpainting_model.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
exp_name = 'deepfillv1_256x256_8x2_places'
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='DeepFillv1Inpaintor',
|
||||||
|
encdec=dict(
|
||||||
|
type='DeepFillEncoderDecoder',
|
||||||
|
stage1=dict(
|
||||||
|
type='GLEncoderDecoder',
|
||||||
|
encoder=dict(type='DeepFillEncoder', padding_mode='reflect'),
|
||||||
|
decoder=dict(
|
||||||
|
type='DeepFillDecoder',
|
||||||
|
in_channels=128,
|
||||||
|
padding_mode='reflect'),
|
||||||
|
dilation_neck=dict(
|
||||||
|
type='GLDilationNeck',
|
||||||
|
in_channels=128,
|
||||||
|
act_cfg=dict(type='ELU'),
|
||||||
|
padding_mode='reflect')),
|
||||||
|
stage2=dict(
|
||||||
|
type='DeepFillRefiner',
|
||||||
|
encoder_attention=dict(
|
||||||
|
type='DeepFillEncoder',
|
||||||
|
encoder_type='stage2_attention',
|
||||||
|
padding_mode='reflect'),
|
||||||
|
encoder_conv=dict(
|
||||||
|
type='DeepFillEncoder',
|
||||||
|
encoder_type='stage2_conv',
|
||||||
|
padding_mode='reflect'),
|
||||||
|
dilation_neck=dict(
|
||||||
|
type='GLDilationNeck',
|
||||||
|
in_channels=128,
|
||||||
|
act_cfg=dict(type='ELU'),
|
||||||
|
padding_mode='reflect'),
|
||||||
|
contextual_attention=dict(
|
||||||
|
type='ContextualAttentionNeck',
|
||||||
|
in_channels=128,
|
||||||
|
padding_mode='reflect'),
|
||||||
|
decoder=dict(
|
||||||
|
type='DeepFillDecoder',
|
||||||
|
in_channels=256,
|
||||||
|
padding_mode='reflect'))),
|
||||||
|
disc=dict(
|
||||||
|
type='DeepFillv1Discriminators',
|
||||||
|
global_disc_cfg=dict(
|
||||||
|
type='MultiLayerDiscriminator',
|
||||||
|
in_channels=3,
|
||||||
|
max_channels=256,
|
||||||
|
fc_in_channels=65536,
|
||||||
|
fc_out_channels=1,
|
||||||
|
num_convs=4,
|
||||||
|
norm_cfg=None,
|
||||||
|
act_cfg=dict(type='ELU'),
|
||||||
|
out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
|
||||||
|
local_disc_cfg=dict(
|
||||||
|
type='MultiLayerDiscriminator',
|
||||||
|
in_channels=3,
|
||||||
|
max_channels=512,
|
||||||
|
fc_in_channels=32768,
|
||||||
|
fc_out_channels=1,
|
||||||
|
num_convs=4,
|
||||||
|
norm_cfg=None,
|
||||||
|
act_cfg=dict(type='ELU'),
|
||||||
|
out_act_cfg=dict(type='LeakyReLU', negative_slope=0.2))),
|
||||||
|
stage1_loss_type=('loss_l1_hole', 'loss_l1_valid'),
|
||||||
|
stage2_loss_type=('loss_l1_hole', 'loss_l1_valid', 'loss_gan'),
|
||||||
|
loss_gan=dict(type='GANLoss', gan_type='wgan', loss_weight=0.0001),
|
||||||
|
loss_l1_hole=dict(type='L1Loss', loss_weight=1.0),
|
||||||
|
loss_l1_valid=dict(type='L1Loss', loss_weight=1.0),
|
||||||
|
loss_gp=dict(type='GradientPenaltyLoss', loss_weight=10.0),
|
||||||
|
loss_disc_shift=dict(type='DiscShiftLoss', loss_weight=0.001),
|
||||||
|
pretrained=None)
|
||||||
|
|
||||||
|
test_cfg = dict(metrics=['l1', 'psnr', 'ssim'])
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', key='gt_img'),
|
||||||
|
dict(
|
||||||
|
type='LoadMask',
|
||||||
|
mask_mode='bbox',
|
||||||
|
mask_config=dict(
|
||||||
|
max_bbox_shape=(128, 128),
|
||||||
|
max_bbox_delta=40,
|
||||||
|
min_margin=20,
|
||||||
|
img_shape=(256, 256))),
|
||||||
|
dict(type='Crop', keys=['gt_img'], crop_size=(384, 384), random_crop=True),
|
||||||
|
dict(type='Resize', keys=['gt_img'], scale=(256, 256), keep_ratio=False),
|
||||||
|
dict(
|
||||||
|
type='Normalize',
|
||||||
|
keys=['gt_img'],
|
||||||
|
mean=[127.5, 127.5, 127.5],
|
||||||
|
std=[127.5, 127.5, 127.5],
|
||||||
|
to_rgb=False),
|
||||||
|
dict(type='GetMaskedImage'),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['gt_img', 'masked_img', 'mask', 'mask_bbox'],
|
||||||
|
meta_keys=['gt_img_path']),
|
||||||
|
dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask']),
|
||||||
|
dict(type='ToTensor', keys=['mask_bbox'])
|
||||||
|
]
|
||||||
|
data = dict(
|
||||||
|
test_dataloader=dict(samples_per_gpu=1),
|
||||||
|
test=dict(
|
||||||
|
type='ImgInpaintingDataset',
|
||||||
|
ann_file='tests/test_codebase/test_mmedit/data/ann_file.txt',
|
||||||
|
data_prefix='tests/test_codebase/test_mmedit/data',
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
test_mode=True))
|
128
tests/test_codebase/test_mmedit/test_inpainting.py
Normal file
128
tests/test_codebase/test_mmedit/test_inpainting.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmdeploy.apis import build_task_processor
|
||||||
|
from mmdeploy.utils import load_config
|
||||||
|
from mmdeploy.utils.test import SwitchBackendWrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def model_cfg():
|
||||||
|
cfg = 'tests/test_codebase/test_mmedit/data/inpainting_model.py'
|
||||||
|
return load_config(cfg)[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def deploy_cfg():
|
||||||
|
return mmcv.Config(
|
||||||
|
dict(
|
||||||
|
backend_config=dict(type='onnxruntime'),
|
||||||
|
codebase_config=dict(type='mmedit', task='Inpainting'),
|
||||||
|
onnx_config=dict(
|
||||||
|
type='onnx',
|
||||||
|
export_params=True,
|
||||||
|
keep_initializers_as_inputs=False,
|
||||||
|
opset_version=11,
|
||||||
|
input_shape=None,
|
||||||
|
input_names=['masked_img', 'mask'],
|
||||||
|
output_names=['fake_img'])))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def input_img():
|
||||||
|
return np.random.rand(32, 32, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def model_input():
|
||||||
|
return dict(
|
||||||
|
masked_img=np.random.rand(32, 32, 3),
|
||||||
|
mask=np.random.randint(0, 2, (32, 32)))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def task_processor(model_cfg, deploy_cfg):
|
||||||
|
return build_task_processor(model_cfg, deploy_cfg, device='cpu')
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_pytorch_model(task_processor):
|
||||||
|
torch_model = task_processor.init_pytorch_model(model_checkpoint=None)
|
||||||
|
assert torch_model is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def backend_model(task_processor):
|
||||||
|
from mmdeploy.backend.onnxruntime import ORTWrapper
|
||||||
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=dict(fake_img=torch.rand(3, 32, 32)))
|
||||||
|
yield task_processor.init_backend_model([''])
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_backend_model(backend_model):
|
||||||
|
assert backend_model is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_input(task_processor, input_img):
|
||||||
|
inputs, _ = task_processor.create_input(
|
||||||
|
input_img, img_shape=input_img.shape[:2])
|
||||||
|
assert 'masked_img' in inputs
|
||||||
|
assert 'mask' in inputs
|
||||||
|
|
||||||
|
|
||||||
|
def test_visualize(backend_model, task_processor, model_input, input_img):
|
||||||
|
result = task_processor.run_inference(backend_model, model_input)
|
||||||
|
with tempfile.TemporaryDirectory() as dir:
|
||||||
|
filename = dir + 'tmp.jpg'
|
||||||
|
task_processor.visualize(backend_model, input_img, result[0], filename,
|
||||||
|
'onnxruntime')
|
||||||
|
assert os.path.exists(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_inference(backend_model, task_processor, model_input):
|
||||||
|
results = task_processor.run_inference(backend_model, model_input)
|
||||||
|
assert results is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tensor_from_input(task_processor, model_input):
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
task_processor.get_tensor_from_input(model_input)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_partition_cfg(task_processor):
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
task_processor.get_partition_cfg(None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_dataset(task_processor):
|
||||||
|
data = dict(
|
||||||
|
test=dict(
|
||||||
|
type='ImgInpaintingDataset',
|
||||||
|
ann_file='tests/test_codebase/test_mmedit/data/ann_file.txt',
|
||||||
|
data_prefix='tests/test_codebase/test_mmedit/data',
|
||||||
|
pipeline=[
|
||||||
|
dict(type='LoadImageFromFile', key='gt_img'),
|
||||||
|
dict(type='LoadMask')
|
||||||
|
]))
|
||||||
|
dataset_cfg = mmcv.Config(dict(data=data))
|
||||||
|
dataset = task_processor.build_dataset(
|
||||||
|
dataset_cfg=dataset_cfg, dataset_type='test')
|
||||||
|
assert dataset is not None, 'Failed to build dataset'
|
||||||
|
dataloader = task_processor.build_dataloader(dataset, 1, 1)
|
||||||
|
assert dataloader is not None, 'Failed to build dataloader'
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_gpu_test(backend_model, model_cfg, task_processor):
|
||||||
|
from mmcv.parallel import MMDataParallel
|
||||||
|
dataset = task_processor.build_dataset(model_cfg, dataset_type='test')
|
||||||
|
assert dataset is not None, 'Failed to build dataset'
|
||||||
|
dataloader = task_processor.build_dataloader(dataset, 1, 1)
|
||||||
|
assert dataloader is not None, 'Failed to build dataloader'
|
||||||
|
backend_model = MMDataParallel(backend_model, device_ids=[0])
|
||||||
|
outputs = task_processor.single_gpu_test(backend_model, dataloader)
|
||||||
|
assert outputs is not None, 'Failed to test model'
|
48
tests/test_codebase/test_mmedit/test_inpainting_model.py
Normal file
48
tests/test_codebase/test_mmedit/test_inpainting_model.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmdeploy.utils import Backend, load_config
|
||||||
|
from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
|
||||||
|
|
||||||
|
|
||||||
|
@backend_checker(Backend.ONNXRUNTIME)
|
||||||
|
class TestEnd2EndModel:
|
||||||
|
|
||||||
|
@pytest.fixture(scope='class')
|
||||||
|
def end2end_model(self):
|
||||||
|
# force add backend wrapper regardless of plugins
|
||||||
|
# make sure ONNXRuntimeEditor can use ORTWrapper inside itself
|
||||||
|
from mmdeploy.backend.onnxruntime import ORTWrapper
|
||||||
|
from mmdeploy.codebase.mmedit.deploy.inpainting_model import \
|
||||||
|
End2EndModel
|
||||||
|
|
||||||
|
# simplify backend inference
|
||||||
|
with SwitchBackendWrapper(ORTWrapper) as wrapper:
|
||||||
|
wrapper.set(outputs=dict(fake_img=torch.rand(3, 32, 32)))
|
||||||
|
deploy_cfg = mmcv.Config(
|
||||||
|
dict(
|
||||||
|
onnx_config=dict(
|
||||||
|
input_names=['masked_img', 'mask'],
|
||||||
|
output_names=['fake_img'])))
|
||||||
|
model_cfg = load_config(
|
||||||
|
'tests/test_codebase/test_mmedit/data/inpainting_model.py')[0]
|
||||||
|
model = End2EndModel(Backend.ONNXRUNTIME, [''], 'cpu', model_cfg,
|
||||||
|
deploy_cfg)
|
||||||
|
yield model
|
||||||
|
|
||||||
|
def test_forward(self, end2end_model):
|
||||||
|
masked_img = np.random.rand(3, 32, 32)
|
||||||
|
mask = np.random.randint(0, 2, (1, 32, 32))
|
||||||
|
|
||||||
|
results = end2end_model.forward(masked_img, mask, test_mode=False)
|
||||||
|
assert results is not None
|
||||||
|
|
||||||
|
results = end2end_model.forward(
|
||||||
|
masked_img,
|
||||||
|
torch.tensor(mask),
|
||||||
|
test_mode=True,
|
||||||
|
gt_img=torch.tensor(results[0]))
|
||||||
|
assert results is not None
|
Loading…
x
Reference in New Issue
Block a user