remove deploy test loop and enable edit SDK (#1083)
parent
11409d7c98
commit
f26b352b7d
|
@ -1,12 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from mmengine.device import get_device
|
||||
from mmengine.logging import MMLogger
|
||||
from mmengine.model import BaseModel
|
||||
from mmengine.registry import LOOPS
|
||||
from mmengine.runner import Runner, TestLoop, autocast
|
||||
from mmengine.runner import Runner
|
||||
|
||||
|
||||
class DeployTestRunner(Runner):
|
||||
|
@ -64,38 +62,3 @@ class DeployTestRunner(Runner):
|
|||
log_file = self._log_file
|
||||
|
||||
return super().build_logger(log_level, log_file, **kwargs)
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
class DeployTestLoop(TestLoop):
|
||||
"""Loop for test. To skip data_preprocessor for SDK.
|
||||
|
||||
Args:
|
||||
runner (Runner): A reference of runner.
|
||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||
build a dataloader.
|
||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||
fp16 (bool): Whether to enable fp16 testing. Defaults to
|
||||
False.
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||
"""Iterate one mini-batch.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||
"""
|
||||
self.runner.call_hook(
|
||||
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
||||
# predictions should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
# skip data_preprocessor to avoid Normalize and Padding for SDK
|
||||
outputs = self.runner.model._run_forward(
|
||||
data_batch, mode='predict')
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_test_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=outputs)
|
||||
|
|
|
@ -90,6 +90,10 @@ class End2EndModel(BaseBackendModel):
|
|||
class SDKEnd2EndModel(End2EndModel):
|
||||
"""SDK inference class, converts SDK output to mmcls format."""
|
||||
|
||||
def __init__(self, *arg, **kwargs):
|
||||
kwargs['data_preprocessor'] = None
|
||||
super().__init__(*arg, **kwargs)
|
||||
|
||||
def forward(self,
|
||||
inputs: Sequence[torch.Tensor],
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
|
|
|
@ -307,9 +307,6 @@ class SuperResolution(BaseTask):
|
|||
'valid_ratio'
|
||||
]
|
||||
preprocess = model_cfg.test_pipeline
|
||||
for item in preprocess:
|
||||
if 'Normalize' == item['type'] and 'std' in item:
|
||||
item['std'] = [255, 255, 255]
|
||||
|
||||
preprocess.insert(1, model_cfg.model.data_preprocessor)
|
||||
transforms = preprocess
|
||||
|
@ -322,6 +319,7 @@ class SuperResolution(BaseTask):
|
|||
transform['type'] = 'ImageToTensor'
|
||||
if transform['type'] == 'EditDataPreprocessor':
|
||||
transform['type'] = 'Normalize'
|
||||
transform['to_rgb'] = transform.get('to_rgb', False)
|
||||
if transform['type'] == 'PackEditInputs':
|
||||
meta_keys += transform[
|
||||
'meta_keys'] if 'meta_keys' in transform else []
|
||||
|
|
|
@ -113,8 +113,11 @@ class End2EndModel(BaseBackendModel):
|
|||
class SDKEnd2EndModel(End2EndModel):
|
||||
"""SDK inference class, converts SDK output to mmedit format."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,
|
||||
lq: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
mode: str = 'predict',
|
||||
*args,
|
||||
|
@ -135,8 +138,25 @@ class SDKEnd2EndModel(End2EndModel):
|
|||
Returns:
|
||||
list | dict: High resolution image or a evaluation results.
|
||||
"""
|
||||
output = self.wrapper.invoke(lq[0].contiguous().detach().cpu().numpy())
|
||||
return [output]
|
||||
if hasattr(self.data_preprocessor, 'destructor'):
|
||||
inputs = self.data_preprocessor.destructor(
|
||||
inputs.to(self.data_preprocessor.input_std.device))
|
||||
outputs = []
|
||||
for i in range(inputs.shape[0]):
|
||||
output = self.wrapper.invoke(inputs[i].permute(
|
||||
1, 2, 0).contiguous().detach().cpu().numpy())
|
||||
outputs.append(
|
||||
torch.from_numpy(output).permute(2, 0, 1).contiguous())
|
||||
outputs = torch.stack(outputs, 0) / 255.
|
||||
if hasattr(self.data_preprocessor, 'destructor'):
|
||||
outputs = self.data_preprocessor.destructor(
|
||||
outputs.to(self.data_preprocessor.outputs_std.device))
|
||||
|
||||
for i, sr_pred in enumerate(outputs):
|
||||
pred = EditDataSample()
|
||||
pred.set_data(dict(pred_img=PixelData(**dict(data=sr_pred))))
|
||||
data_samples[i].set_data(dict(output=pred))
|
||||
return data_samples
|
||||
|
||||
|
||||
def build_super_resolution_model(
|
||||
|
|
|
@ -23,10 +23,8 @@ class End2EndModel(BaseBackendModel):
|
|||
backend_files (Sequence[str]): Paths to all required backend files(e.g.
|
||||
'.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
|
||||
device (str): A string represents device type.
|
||||
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
|
||||
Config object.
|
||||
model_cfg (str | mmengine.Config): Model config file or loaded Config
|
||||
object.
|
||||
deploy_cfg (mmengine.Config | None): Loaded Config object of MMDeploy.
|
||||
model_cfg (mmengine.Config | None): Loaded Config object of MMOCR.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -34,19 +32,18 @@ class End2EndModel(BaseBackendModel):
|
|||
backend: Backend,
|
||||
backend_files: Sequence[str],
|
||||
device: str,
|
||||
deploy_cfg: Union[str, mmengine.Config] = None,
|
||||
model_cfg: Union[str, mmengine.Config] = None,
|
||||
deploy_cfg: Optional[mmengine.Config] = None,
|
||||
model_cfg: Optional[mmengine.Config] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
|
||||
model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
|
||||
super(End2EndModel, self).__init__(
|
||||
deploy_cfg=deploy_cfg,
|
||||
data_preprocessor=model_cfg.model.data_preprocessor)
|
||||
self.deploy_cfg = deploy_cfg
|
||||
self.show_score = False
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
self.det_head = MODELS.build(model_cfg.model.det_head)
|
||||
self.data_preprocessor = MODELS.build(
|
||||
model_cfg.model.data_preprocessor)
|
||||
self._init_wrapper(
|
||||
backend=backend,
|
||||
backend_files=backend_files,
|
||||
|
@ -125,6 +122,10 @@ class End2EndModel(BaseBackendModel):
|
|||
class SDKEnd2EndModel(End2EndModel):
|
||||
"""SDK inference class, converts SDK output to mmocr format."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['model_cfg'].model.data_preprocessor = None
|
||||
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Union
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
|
@ -23,10 +23,8 @@ class End2EndModel(BaseBackendModel):
|
|||
backend_files (Sequence[str]): Paths to all required backend files(e.g.
|
||||
'.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
|
||||
device (str): A string represents device type.
|
||||
deploy_cfg (str | mmengine.Config): Deployment config file or loaded
|
||||
Config object.
|
||||
model_cfg (str | mmengine.Config): Model config file or loaded Config
|
||||
object.
|
||||
deploy_cfg (mmengine.Config | None): Loaded Config object of MMDeploy.
|
||||
model_cfg (mmengine.Config | None): Loaded Config object of MMOCR.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -34,10 +32,12 @@ class End2EndModel(BaseBackendModel):
|
|||
backend: Backend,
|
||||
backend_files: Sequence[str],
|
||||
device: str,
|
||||
deploy_cfg: Union[str, mmengine.Config] = None,
|
||||
model_cfg: Union[str, mmengine.Config] = None,
|
||||
deploy_cfg: Optional[mmengine.Config] = None,
|
||||
model_cfg: Optional[mmengine.Config] = None,
|
||||
):
|
||||
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
|
||||
super(End2EndModel, self).__init__(
|
||||
deploy_cfg=deploy_cfg,
|
||||
data_preprocessor=model_cfg.model.data_preprocessor)
|
||||
model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
|
||||
self.deploy_cfg = deploy_cfg
|
||||
self.show_score = False
|
||||
|
@ -53,8 +53,6 @@ class End2EndModel(BaseBackendModel):
|
|||
if decoder.get('dictionary', None) is None:
|
||||
decoder.update(dictionary=self.dictionary)
|
||||
self.decoder = MODELS.build(decoder)
|
||||
self.data_preprocessor = MODELS.build(
|
||||
model_cfg.model.data_preprocessor)
|
||||
self._init_wrapper(
|
||||
backend=backend, backend_files=backend_files, device=device)
|
||||
|
||||
|
@ -113,6 +111,10 @@ class End2EndModel(BaseBackendModel):
|
|||
class SDKEnd2EndModel(End2EndModel):
|
||||
"""SDK inference class, converts SDK output to mmocr format."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['model_cfg'].model.data_preprocessor = None
|
||||
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, inputs: Sequence[torch.Tensor],
|
||||
data_samples: RecSampleList, *args, **kwargs):
|
||||
"""Run forward inference.
|
||||
|
|
|
@ -6,8 +6,7 @@ from copy import deepcopy
|
|||
from mmengine import DictAction
|
||||
|
||||
from mmdeploy.apis import build_task_processor
|
||||
from mmdeploy.utils.config_utils import get_backend, load_config
|
||||
from mmdeploy.utils.constants import Backend
|
||||
from mmdeploy.utils.config_utils import load_config
|
||||
from mmdeploy.utils.timer import TimeCounter
|
||||
|
||||
|
||||
|
@ -89,8 +88,6 @@ def main():
|
|||
|
||||
# load deploy_cfg
|
||||
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
|
||||
if get_backend(deploy_cfg) == Backend.SDK:
|
||||
model_cfg.test_cfg.type = 'DeployTestLoop'
|
||||
|
||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||
if args.work_dir is not None:
|
||||
|
|
Loading…
Reference in New Issue