remove deploy test loop and enable edit SDK (#1083)

pull/1091/head
AllentDan 2022-09-28 16:28:24 +08:00 committed by GitHub
parent 11409d7c98
commit f26b352b7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 54 additions and 69 deletions

View File

@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Union
import torch
from mmengine.device import get_device
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.registry import LOOPS
from mmengine.runner import Runner, TestLoop, autocast
from mmengine.runner import Runner
class DeployTestRunner(Runner):
@ -64,38 +62,3 @@ class DeployTestRunner(Runner):
log_file = self._log_file
return super().build_logger(log_level, log_file, **kwargs)
@LOOPS.register_module()
class DeployTestLoop(TestLoop):
"""Loop for test. To skip data_preprocessor for SDK.
Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 testing. Defaults to
False.
"""
@torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
Args:
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
# skip data_preprocessor to avoid Normalize and Padding for SDK
outputs = self.runner.model._run_forward(
data_batch, mode='predict')
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)

View File

@ -90,6 +90,10 @@ class End2EndModel(BaseBackendModel):
class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmcls format."""
def __init__(self, *arg, **kwargs):
kwargs['data_preprocessor'] = None
super().__init__(*arg, **kwargs)
def forward(self,
inputs: Sequence[torch.Tensor],
data_samples: Optional[List[BaseDataElement]] = None,

View File

@ -307,9 +307,6 @@ class SuperResolution(BaseTask):
'valid_ratio'
]
preprocess = model_cfg.test_pipeline
for item in preprocess:
if 'Normalize' == item['type'] and 'std' in item:
item['std'] = [255, 255, 255]
preprocess.insert(1, model_cfg.model.data_preprocessor)
transforms = preprocess
@ -322,6 +319,7 @@ class SuperResolution(BaseTask):
transform['type'] = 'ImageToTensor'
if transform['type'] == 'EditDataPreprocessor':
transform['type'] = 'Normalize'
transform['to_rgb'] = transform.get('to_rgb', False)
if transform['type'] == 'PackEditInputs':
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []

View File

@ -113,8 +113,11 @@ class End2EndModel(BaseBackendModel):
class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmedit format."""
def __init__(self, *args, **kwargs):
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
def forward(self,
lq: torch.Tensor,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'predict',
*args,
@ -135,8 +138,25 @@ class SDKEnd2EndModel(End2EndModel):
Returns:
list | dict: High resolution image or a evaluation results.
"""
output = self.wrapper.invoke(lq[0].contiguous().detach().cpu().numpy())
return [output]
if hasattr(self.data_preprocessor, 'destructor'):
inputs = self.data_preprocessor.destructor(
inputs.to(self.data_preprocessor.input_std.device))
outputs = []
for i in range(inputs.shape[0]):
output = self.wrapper.invoke(inputs[i].permute(
1, 2, 0).contiguous().detach().cpu().numpy())
outputs.append(
torch.from_numpy(output).permute(2, 0, 1).contiguous())
outputs = torch.stack(outputs, 0) / 255.
if hasattr(self.data_preprocessor, 'destructor'):
outputs = self.data_preprocessor.destructor(
outputs.to(self.data_preprocessor.outputs_std.device))
for i, sr_pred in enumerate(outputs):
pred = EditDataSample()
pred.set_data(dict(pred_img=PixelData(**dict(data=sr_pred))))
data_samples[i].set_data(dict(output=pred))
return data_samples
def build_super_resolution_model(

View File

@ -23,10 +23,8 @@ class End2EndModel(BaseBackendModel):
backend_files (Sequence[str]): Paths to all required backend files(e.g.
'.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
device (str): A string represents device type.
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
Config object.
model_cfg (str | mmengine.Config): Model config file or loaded Config
object.
deploy_cfg (mmengine.Config | None): Loaded Config object of MMDeploy.
model_cfg (mmengine.Config | None): Loaded Config object of MMOCR.
"""
def __init__(
@ -34,19 +32,18 @@ class End2EndModel(BaseBackendModel):
backend: Backend,
backend_files: Sequence[str],
device: str,
deploy_cfg: Union[str, mmengine.Config] = None,
model_cfg: Union[str, mmengine.Config] = None,
deploy_cfg: Optional[mmengine.Config] = None,
model_cfg: Optional[mmengine.Config] = None,
**kwargs,
):
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
super(End2EndModel, self).__init__(
deploy_cfg=deploy_cfg,
data_preprocessor=model_cfg.model.data_preprocessor)
self.deploy_cfg = deploy_cfg
self.show_score = False
from mmocr.registry import MODELS
self.det_head = MODELS.build(model_cfg.model.det_head)
self.data_preprocessor = MODELS.build(
model_cfg.model.data_preprocessor)
self._init_wrapper(
backend=backend,
backend_files=backend_files,
@ -125,6 +122,10 @@ class End2EndModel(BaseBackendModel):
class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmocr format."""
def __init__(self, *args, **kwargs):
kwargs['model_cfg'].model.data_preprocessor = None
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union
from typing import Optional, Sequence, Union
import mmengine
import torch
@ -23,10 +23,8 @@ class End2EndModel(BaseBackendModel):
backend_files (Sequence[str]): Paths to all required backend files(e.g.
'.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
device (str): A string represents device type.
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
Config object.
model_cfg (str | mmengine.Config): Model config file or loaded Config
object.
deploy_cfg (mmengine.Config | None): Loaded Config object of MMDeploy.
model_cfg (mmengine.Config | None): Loaded Config object of MMOCR.
"""
def __init__(
@ -34,10 +32,12 @@ class End2EndModel(BaseBackendModel):
backend: Backend,
backend_files: Sequence[str],
device: str,
deploy_cfg: Union[str, mmengine.Config] = None,
model_cfg: Union[str, mmengine.Config] = None,
deploy_cfg: Optional[mmengine.Config] = None,
model_cfg: Optional[mmengine.Config] = None,
):
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
super(End2EndModel, self).__init__(
deploy_cfg=deploy_cfg,
data_preprocessor=model_cfg.model.data_preprocessor)
model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
self.deploy_cfg = deploy_cfg
self.show_score = False
@ -53,8 +53,6 @@ class End2EndModel(BaseBackendModel):
if decoder.get('dictionary', None) is None:
decoder.update(dictionary=self.dictionary)
self.decoder = MODELS.build(decoder)
self.data_preprocessor = MODELS.build(
model_cfg.model.data_preprocessor)
self._init_wrapper(
backend=backend, backend_files=backend_files, device=device)
@ -113,6 +111,10 @@ class End2EndModel(BaseBackendModel):
class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmocr format."""
def __init__(self, *args, **kwargs):
kwargs['model_cfg'].model.data_preprocessor = None
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
def forward(self, inputs: Sequence[torch.Tensor],
data_samples: RecSampleList, *args, **kwargs):
"""Run forward inference.

View File

@ -6,8 +6,7 @@ from copy import deepcopy
from mmengine import DictAction
from mmdeploy.apis import build_task_processor
from mmdeploy.utils.config_utils import get_backend, load_config
from mmdeploy.utils.constants import Backend
from mmdeploy.utils.config_utils import load_config
from mmdeploy.utils.timer import TimeCounter
@ -89,8 +88,6 @@ def main():
# load deploy_cfg
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
if get_backend(deploy_cfg) == Backend.SDK:
model_cfg.test_cfg.type = 'DeployTestLoop'
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None: