Support mmseg:dev-1.x (#790)

* support pspnet + ort

* add rewriting for adapt_avg_pool

* test pspnet

* resize seg_pred to original image shape

* run with test.py

* keep as original

* fix ut of segmentation

* update var name

* fix export to torchscript

* sync with mmseg:test-1.x branch

* fix ut

* fix regression test for mmseg

* fix mmseg.ops

* update mmseg yml

* fix mmseg2.0 sdk

* fix adaptive pool

* update rewriting and tests

* fix sdk inputs
This commit is contained in:
RunningLeon 2022-09-14 20:08:52 +08:00 committed by GitHub
parent 0aad6359e2
commit 06028d6a21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 641 additions and 1073 deletions

View File

@ -4,5 +4,7 @@ codebase_config = dict(model_type='sdk')
backend_config = dict(pipeline=[ backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape']) dict(type='LoadAnnotations'),
dict(
type='PackSegInputs', meta_keys=['img_path', 'ori_shape', 'img_shape'])
]) ])

View File

@ -78,7 +78,7 @@ class PipelineCaller:
call_id = self._call_id if call_id is None else call_id call_id = self._call_id if call_id is None else call_id
if call_id not in self._mp_dict: if call_id not in self._mp_dict:
get_root_logger().error( get_root_logger().error(
f'`{self._func_name}` with Call id: {call_id} failed. exit.') f'`{self._func_name}` with Call id: {call_id} failed.')
exit(1) exit(1)
ret = self._mp_dict[call_id] ret = self._mp_dict[call_id]
self._mp_dict.pop(call_id) self._mp_dict.pop(call_id)

View File

@ -42,7 +42,10 @@ def torch2torchscript(img: Any,
task_processor = build_task_processor(model_cfg, deploy_cfg, device) task_processor = build_task_processor(model_cfg, deploy_cfg, device)
torch_model = task_processor.build_pytorch_model(model_checkpoint) torch_model = task_processor.build_pytorch_model(model_checkpoint)
_, model_inputs = task_processor.create_input(img, input_shape) _, model_inputs = task_processor.create_input(
img,
input_shape,
data_preprocessor=getattr(torch_model, 'data_preprocessor', None))
if not isinstance(model_inputs, torch.Tensor): if not isinstance(model_inputs, torch.Tensor):
model_inputs = model_inputs[0] model_inputs = model_inputs[0]

View File

@ -91,7 +91,9 @@ class BaseTask(metaclass=ABCMeta):
nn.Module: An initialized torch model generated by other OpenMMLab nn.Module: An initialized torch model generated by other OpenMMLab
codebases. codebases.
""" """
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import MODELS from mmengine.registry import MODELS
model = deepcopy(self.model_cfg.model) model = deepcopy(self.model_cfg.model)
preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {})) preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
model.setdefault('data_preprocessor', preprocess_cfg) model.setdefault('data_preprocessor', preprocess_cfg)
@ -99,9 +101,10 @@ class BaseTask(metaclass=ABCMeta):
if model_checkpoint is not None: if model_checkpoint is not None:
from mmengine.runner.checkpoint import load_checkpoint from mmengine.runner.checkpoint import load_checkpoint
load_checkpoint(model, model_checkpoint) load_checkpoint(model, model_checkpoint)
model = revert_sync_batchnorm(model)
model = model.to(self.device) model = model.to(self.device)
model.eval() model.eval()
return model return model
def build_dataset(self, def build_dataset(self,
@ -280,7 +283,10 @@ class BaseTask(metaclass=ABCMeta):
visualizer = self.get_visualizer(window_name, save_dir) visualizer = self.get_visualizer(window_name, save_dir)
name = osp.splitext(save_name)[0] name = osp.splitext(save_name)[0]
if isinstance(image, str):
image = mmcv.imread(image, channel_order='rgb') image = mmcv.imread(image, channel_order='rgb')
assert isinstance(image, np.ndarray)
visualizer.add_datasample( visualizer.add_datasample(
name, name,
image, image,

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .mmsegmentation import MMSegmentation
from .segmentation import Segmentation from .segmentation import Segmentation
__all__ = ['MMSegmentation', 'Segmentation'] __all__ = ['Segmentation']

View File

@ -1,148 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
import mmengine
import torch
from mmcv.utils import Registry
from torch.utils.data import DataLoader, Dataset
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, get_task_type
def __build_mmseg_task(model_cfg: mmengine.Config, deploy_cfg: mmengine.Config,
device: str, registry: Registry) -> BaseTask:
task = get_task_type(deploy_cfg)
return registry.module_dict[task.value](model_cfg, deploy_cfg, device)
MMSEG_TASK = Registry('mmseg_tasks', build_func=__build_mmseg_task)
@CODEBASE.register_module(Codebase.MMSEG.value)
class MMSegmentation(MMCodebase):
"""mmsegmentation codebase class."""
task_registry = MMSEG_TASK
def __init__(self):
super(MMSegmentation, self).__init__()
@staticmethod
def build_task_processor(model_cfg: mmengine.Config,
deploy_cfg: mmengine.Config, device: str):
"""The interface to build the task processors of mmseg.
Args:
model_cfg (str | mmengine.Config): Model config file.
deploy_cfg (str | mmengine.Config): Deployment config file.
device (str): A string specifying device type.
Returns:
BaseTask: A task processor.
"""
return MMSEG_TASK.build(model_cfg, deploy_cfg, device)
@staticmethod
def build_dataset(dataset_cfg: Union[str, mmengine.Config],
dataset_type: str = 'val',
**kwargs) -> Dataset:
"""Build dataset for segmentation.
Args:
dataset_cfg (str | mmengine.Config): The input dataset config.
dataset_type (str): A string represents dataset type, e.g.: 'train'
, 'test', 'val'. Defaults to 'val'.
Returns:
Dataset: A PyTorch dataset.
"""
from mmseg.datasets import build_dataset as build_dataset_mmseg
assert dataset_type in dataset_cfg.data
data_cfg = dataset_cfg.data[dataset_type]
dataset = build_dataset_mmseg(data_cfg)
return dataset
@staticmethod
def build_dataloader(dataset: Dataset,
samples_per_gpu: int,
workers_per_gpu: int,
num_gpus: int = 1,
dist: bool = False,
shuffle: bool = False,
seed: Optional[int] = None,
drop_last: bool = False,
pin_memory: bool = True,
persistent_workers: bool = True,
**kwargs) -> DataLoader:
"""Build dataloader for segmentation.
Args:
dataset (Dataset): Input dataset.
samples_per_gpu (int): Number of training samples on each GPU, i.e.
,batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data
loading for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed
training. dist (bool): Distributed training/test or not.
Defaults to `False`.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Defaults to `False`.
seed (int): An integer set to be seed. Default is `None`.
drop_last (bool): Whether to drop the last incomplete batch in
epoch. Default to `False`.
pin_memory (bool): Whether to use pin_memory in DataLoader.
Default is `True`.
persistent_workers (bool): If `True`, the data loader will not
shutdown the worker processes after a dataset has been
consumed once. This allows to maintain the workers Dataset
instances alive. The argument also has effect in
PyTorch>=1.7.0. Default is `True`.
kwargs: Any other keyword argument to be used to initialize
DataLoader.
Returns:
DataLoader: A PyTorch dataloader.
"""
from mmseg.datasets import build_dataloader as build_dataloader_mmseg
return build_dataloader_mmseg(
dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=num_gpus,
dist=dist,
shuffle=shuffle,
seed=seed,
drop_last=drop_last,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
**kwargs)
@staticmethod
def single_gpu_test(model: torch.nn.Module,
data_loader: DataLoader,
show: bool = False,
out_dir: Optional[str] = None,
pre_eval: bool = True,
**kwargs):
"""Run test with single gpu.
Args:
model (torch.nn.Module): Input model from nn.Module.
data_loader (DataLoader): PyTorch data loader.
show (bool): Specifying whether to show plotted results. Defaults
to `False`.
out_dir (str): A directory to save results, defaults to `None`.
pre_eval (bool): Use dataset.pre_eval() function to generate
pre_results for metric evaluation. Mutually exclusive with
efficient_test and format_results. Default: False.
Returns:
list: The prediction results.
"""
from mmseg.apis import single_gpu_test
outputs = single_gpu_test(
model, data_loader, show, out_dir, pre_eval=pre_eval, **kwargs)
return outputs

View File

@ -1,15 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Sequence, Tuple, Union import os.path as osp
from collections import defaultdict
from copy import deepcopy
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import mmcv import mmcv
import mmengine import mmengine
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import Dataset from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry
from mmdeploy.codebase.base import BaseTask from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Task, get_input_shape from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
from .mmsegmentation import MMSEG_TASK
def process_model_config(model_cfg: mmengine.Config, def process_model_config(model_cfg: mmengine.Config,
@ -27,22 +31,81 @@ def process_model_config(model_cfg: mmengine.Config,
Returns: Returns:
mmengine.Config: the model config after processing. mmengine.Config: the model config after processing.
""" """
from mmseg.apis.inference import LoadImage cfg = deepcopy(model_cfg)
cfg = model_cfg.copy()
if isinstance(imgs[0], np.ndarray): if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy() cfg = cfg.copy()
# set loading pipeline type # set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
# remove some training related pipeline
removed_indices = []
for i in range(len(cfg.test_pipeline)):
if cfg.test_pipeline[i]['type'] in ['LoadAnnotations']:
removed_indices.append(i)
for i in reversed(removed_indices):
cfg.test_pipeline.pop(i)
# for static exporting # for static exporting
if input_shape is not None: if input_shape is not None:
cfg.data.test.pipeline[1]['img_scale'] = tuple(input_shape) found_resize = False
cfg.data.test.pipeline[1]['transforms'][0]['keep_ratio'] = False for i in range(len(cfg.test_pipeline)):
cfg.data.test.pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] if 'Resize' == cfg.test_pipeline[i]['type']:
cfg.test_pipeline[i]['scale'] = tuple(input_shape)
cfg.test_pipeline[i]['keep_ratio'] = False
found_resize = True
if not found_resize:
logger = get_root_logger()
logger.warning(
f'Not found Resize in test_pipeline: {cfg.test_pipeline}')
return cfg return cfg
def _get_dataset_metainfo(model_cfg: Config):
"""Get metainfo of dataset.
Args:
model_cfg Config: Input model Config object.
Returns:
(list[str], list[np.ndarray]): Class names and palette
"""
from mmseg import datasets # noqa
from mmseg.registry import DATASETS
module_dict = DATASETS.module_dict
for dataloader_name in [
'test_dataloader', 'val_dataloader', 'train_dataloader'
]:
if dataloader_name not in model_cfg:
continue
dataloader_cfg = model_cfg[dataloader_name]
dataset_cfg = dataloader_cfg.dataset
dataset_mmseg = module_dict.get(dataset_cfg.type, None)
if dataset_mmseg is None:
continue
if hasattr(dataset_mmseg, '_load_metainfo') and isinstance(
dataset_mmseg._load_metainfo, Callable):
meta = dataset_mmseg._load_metainfo(
dataset_cfg.get('metainfo', None))
if meta is not None:
return meta
if hasattr(dataset_mmseg, 'METAINFO'):
return dataset_mmseg.METAINFO
return None
MMSEG_TASK = Registry('mmseg_tasks')
@CODEBASE.register_module(Codebase.MMSEG.value)
class MMSegmentation(MMCodebase):
"""mmsegmentation codebase class."""
task_registry = MMSEG_TASK
@MMSEG_TASK.register_module(Task.SEGMENTATION.value) @MMSEG_TASK.register_module(Task.SEGMENTATION.value)
class Segmentation(BaseTask): class Segmentation(BaseTask):
"""Segmentation task class. """Segmentation task class.
@ -70,43 +133,23 @@ class Segmentation(BaseTask):
nn.Module: An initialized backend model. nn.Module: An initialized backend model.
""" """
from .segmentation_model import build_segmentation_model from .segmentation_model import build_segmentation_model
data_preprocessor = self.model_cfg.model.data_preprocessor
model = build_segmentation_model( model = build_segmentation_model(
model_files, model_files,
self.model_cfg, self.model_cfg,
self.deploy_cfg, self.deploy_cfg,
device=self.device, device=self.device,
**kwargs) data_preprocessor=data_preprocessor)
return model.eval() model = model.to(self.device).eval()
return model
def build_pytorch_model(self, def create_input(
model_checkpoint: Optional[str] = None, self,
cfg_options: Optional[Dict] = None, imgs: Union[str, np.ndarray, Sequence],
**kwargs) -> torch.nn.Module: input_shape: Sequence[int] = None,
"""Initialize torch model. data_preprocessor: Optional[BaseDataPreprocessor] = None
) -> Tuple[Dict, torch.Tensor]:
Args:
model_checkpoint (str): The checkpoint file of torch model,
defaults to `None`.
cfg_options (dict): Optional config key-pair parameters.
Returns:
nn.Module: An initialized torch model generated by OpenMMLab
codebases.
"""
from mmcv.cnn.utils import revert_sync_batchnorm
if self.from_mmrazor:
from mmrazor.apis import init_mmseg_model as init_segmentor
else:
from mmseg.apis import init_segmentor
model = init_segmentor(self.model_cfg, model_checkpoint, self.device)
model = revert_sync_batchnorm(model)
return model.eval()
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape: Sequence[int] = None) \
-> Tuple[Dict, torch.Tensor]:
"""Create input for segmentor. """Create input for segmentor.
Args: Args:
@ -118,43 +161,64 @@ class Segmentation(BaseTask):
Returns: Returns:
tuple: (data, img), meta information for the input image and input. tuple: (data, img), meta information for the input image and input.
""" """
from mmcv.parallel import collate, scatter from mmengine.dataset import Compose
from mmseg.datasets.pipelines import Compose
if not isinstance(imgs, (list, tuple)): if not isinstance(imgs, (tuple, list)):
imgs = [imgs] imgs = [imgs]
cfg = process_model_config(self.model_cfg, imgs, input_shape) cfg = process_model_config(self.model_cfg, imgs, input_shape)
test_pipeline = Compose(cfg.data.test.pipeline) test_pipeline = Compose(cfg.test_pipeline)
data_list = [] batch_data = defaultdict(list)
for img in imgs: for img in imgs:
# prepare data if isinstance(img, str):
data = dict(img_path=img)
else:
data = dict(img=img) data = dict(img=img)
# build the data pipeline
data = test_pipeline(data) data = test_pipeline(data)
data_list.append(data) batch_data['inputs'].append(data['inputs'])
batch_data['data_samples'].append(data['data_samples'])
data = collate(data_list, samples_per_gpu=len(imgs)) # batch_data = pseudo_collate([batch_data])
if data_preprocessor is not None:
batch_data = data_preprocessor(batch_data, False)
input_tensor = batch_data['inputs']
else:
input_tensor = BaseTask.get_tensor_from_input(batch_data)
return batch_data, input_tensor
data['img_metas'] = [ def get_visualizer(self, name: str, save_dir: str):
img_metas.data[0] for img_metas in data['img_metas'] """
]
data['img'] = [img.data[0][None, :] for img in data['img']]
if self.device != 'cpu':
data = scatter(data, [self.device])[0]
return data, data['img'] Args:
name:
save_dir:
Returns:
"""
# import to make SegLocalVisualizer could be built
from mmseg.visualization import SegLocalVisualizer # noqa: F401,F403
visualizer = super().get_visualizer(name, save_dir)
# force to change save_dir instead of
# save_dir/vis_data/vis_image/xx.jpg
visualizer._vis_backends['LocalVisBackend']._save_dir = save_dir
visualizer._vis_backends['LocalVisBackend']._img_save_dir = '.'
metainfo = _get_dataset_metainfo(self.model_cfg)
if metainfo is not None:
visualizer.dataset_meta = metainfo
return visualizer
def visualize(self, def visualize(self,
model,
image: Union[str, np.ndarray], image: Union[str, np.ndarray],
result: list, result: list,
output_file: str, output_file: str,
window_name: str = '', window_name: str = '',
show_result: bool = False, show_result: bool = False,
opacity: float = 0.5): opacity: float = 0.5,
"""Visualize predictions of a model. **kwargs):
"""Visualize segmentation predictions.
Args: Args:
model (nn.Module): Input model.
image (str | np.ndarray): Input image to draw predictions on. image (str | np.ndarray): Input image to draw predictions on.
result (list): A list of predictions. result (list): A list of predictions.
output_file (str): Output file to save drawn image. output_file (str): Output file to save drawn image.
@ -165,88 +229,18 @@ class Segmentation(BaseTask):
opacity: (float): Opacity of painted segmentation map. opacity: (float): Opacity of painted segmentation map.
Defaults to `0.5`. Defaults to `0.5`.
""" """
show_img = mmcv.imread(image) if isinstance(image, str) else image save_dir, filename = osp.split(output_file)
output_file = None if show_result else output_file visualizer = self.get_visualizer(window_name, save_dir)
# Need to wrapper the result with list for mmseg name = osp.splitext(filename)[0]
result = [result] if isinstance(image, str):
model.show_result( image = mmcv.imread(image, channel_order='rgb')
show_img, visualizer.add_datasample(
result, name, image, data_sample=result.cpu(), show=show_result)
out_file=output_file,
win_name=window_name,
show=show_result,
opacity=opacity)
@staticmethod
def run_inference(model, model_inputs: Dict[str, torch.Tensor]):
"""Run inference once for a segmentation model of mmseg.
Args:
model (nn.Module): Input model.
model_inputs (dict): A dict containing model inputs tensor and
meta info.
Returns:
list: The predictions of model inference.
"""
return model(**model_inputs, return_loss=False, rescale=True)
@staticmethod @staticmethod
def get_partition_cfg(partition_type: str) -> Dict: def get_partition_cfg(partition_type: str) -> Dict:
raise NotImplementedError('Not supported yet.') raise NotImplementedError('Not supported yet.')
@staticmethod
def get_tensor_from_input(input_data: Dict[str, Any]) -> torch.Tensor:
"""Get input tensor from input data.
Args:
input_data (dict): Input data containing meta info and image
tensor.
Returns:
torch.Tensor: An image in `Tensor`.
"""
return input_data['img'][0]
@staticmethod
def evaluate_outputs(model_cfg,
outputs: Sequence,
dataset: Dataset,
metrics: Optional[str] = None,
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
"""Perform post-processing to predictions of model.
Args:
outputs (list): A list of predictions of model inference.
dataset (Dataset): Input dataset to run test.
model_cfg (mmengine.Config): The model config.
metrics (str): Evaluation metrics, which depends on
the codebase and the dataset, e.g., e.g., "mIoU" for generic
datasets, and "cityscapes" for Cityscapes in mmseg.
out (str): Output result file in pickle format, defaults to `None`.
metric_options (dict): Custom options for evaluation, will be
kwargs for dataset.evaluate() function. Defaults to `None`.
format_only (bool): Format the output results without perform
evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server. Defaults
to `False`.
log_file (str | None): The file to write the evaluation results.
Defaults to `None` and the results will only print on stdout.
"""
from mmcv.utils import get_logger
logger = get_logger('test', log_file=log_file)
if out:
logger.debug(f'writing results to {out}')
mmcv.dump(outputs, out)
kwargs = {} if metric_options is None else metric_options
if format_only:
dataset.format_results(outputs, **kwargs)
if metrics:
dataset.evaluate(outputs, metrics, logger=logger, **kwargs)
def get_preprocess(self) -> Dict: def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK. """Get the preprocess information for SDK.
@ -254,10 +248,28 @@ class Segmentation(BaseTask):
dict: Composed of the preprocess information. dict: Composed of the preprocess information.
""" """
input_shape = get_input_shape(self.deploy_cfg) input_shape = get_input_shape(self.deploy_cfg)
load_from_file = self.model_cfg.data.test.pipeline[0] load_from_file = self.model_cfg.test_pipeline[0]
model_cfg = process_model_config(self.model_cfg, [''], input_shape) model_cfg = process_model_config(self.model_cfg, [''], input_shape)
preprocess = model_cfg.data.test.pipeline preprocess = deepcopy(model_cfg.test_pipeline)
preprocess[0] = load_from_file preprocess[0] = load_from_file
dp = self.model_cfg.data_preprocessor
assert preprocess[1].type == 'Resize'
preprocess[1]['size'] = list(reversed(preprocess[1].pop('scale')))
if preprocess[-1].type == 'PackSegInputs':
preprocess[-1] = dict(
type='Normalize',
mean=dp.mean,
std=dp.std,
to_rgb=dp.bgr_to_rgb)
preprocess.append(dict(type='ImageToTensor', keys=['img']))
preprocess.append(
dict(
type='Collect',
keys=['img'],
meta_keys=[
'img_shape', 'pad_shape', 'ori_shape', 'img_norm_cfg',
'scale_factor'
]))
return preprocess return preprocess
def get_postprocess(self) -> Dict: def get_postprocess(self) -> Dict:

View File

@ -1,26 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
import torch import torch
from mmcv.utils import Registry from mmengine import Config
from mmseg.datasets import DATASETS from mmengine.model import BaseDataPreprocessor
from mmseg.models.segmentors.base import BaseSegmentor from mmengine.registry import Registry
from mmseg.ops import resize from mmengine.structures import BaseDataElement, PixelData
from torch import nn
from mmdeploy.codebase.base import BaseBackendModel from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.utils import (Backend, get_backend, get_codebase_config, from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
load_config) get_root_logger, load_config)
def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs): def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
return registry.module_dict[cls_name](*args, **kwargs) return registry.module_dict[cls_name](*args, **kwargs)
__BACKEND_MODEL = mmcv.utils.Registry( __BACKEND_MODEL = Registry('backend_segmentors')
'backend_segmentors', build_func=__build_backend_model)
@__BACKEND_MODEL.register_module('end2end') @__BACKEND_MODEL.register_module('end2end')
@ -42,14 +39,13 @@ class End2EndModel(BaseBackendModel):
backend: Backend, backend: Backend,
backend_files: Sequence[str], backend_files: Sequence[str],
device: str, device: str,
class_names: Sequence[str], deploy_cfg: Union[str, Config] = None,
palette: np.ndarray, data_preprocessor: Optional[Union[dict, nn.Module]] = None,
deploy_cfg: Union[str, mmcv.Config] = None,
**kwargs): **kwargs):
super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg) super(End2EndModel, self).__init__(
self.CLASSES = class_names deploy_cfg=deploy_cfg, data_preprocessor=data_preprocessor)
self.PALETTE = palette
self.deploy_cfg = deploy_cfg self.deploy_cfg = deploy_cfg
self.device = device
self._init_wrapper( self._init_wrapper(
backend=backend, backend=backend,
backend_files=backend_files, backend_files=backend_files,
@ -67,8 +63,10 @@ class End2EndModel(BaseBackendModel):
deploy_cfg=self.deploy_cfg, deploy_cfg=self.deploy_cfg,
**kwargs) **kwargs)
def forward(self, img: Sequence[torch.Tensor], def forward(self,
img_metas: Sequence[Sequence[dict]], *args, **kwargs): inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'predict'):
"""Run forward inference. """Run forward inference.
Args: Args:
@ -82,78 +80,47 @@ class End2EndModel(BaseBackendModel):
Returns: Returns:
list: A list contains predictions. list: A list contains predictions.
""" """
input_img = img[0].contiguous() assert mode == 'predict', \
outputs = self.forward_test(input_img, img_metas, *args, **kwargs) 'Backend model only support mode==predict,' f' but get {mode}'
seg_pred = outputs[0] if inputs.device != torch.device(self.device):
# whole mode supports dynamic shape get_root_logger().warning(f'expect input device {self.device}'
ori_shape = img_metas[0][0]['ori_shape'] f' but get {inputs.device}.')
if not (ori_shape[0] == seg_pred.shape[-2] inputs = inputs.to(self.device)
and ori_shape[1] == seg_pred.shape[-1]): batch_outputs = self.wrapper({self.input_name:
seg_pred = torch.from_numpy(seg_pred).float() inputs})[self.output_names[0]]
return self.pack_result(batch_outputs, data_samples)
def pack_result(self, batch_outputs, data_samples):
predictions = []
for seg_pred, data_sample in zip(batch_outputs, data_samples):
# resize seg_pred to original image shape
metainfo = data_sample.metainfo
if metainfo['ori_shape'] != metainfo['img_shape']:
from mmseg.models.utils import resize
ori_type = seg_pred.dtype
seg_pred = resize( seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest') seg_pred.unsqueeze(0).to(torch.float32),
seg_pred = seg_pred.long().detach().cpu().numpy() size=metainfo['ori_shape'],
# remove unnecessary dim mode='nearest').squeeze(0).to(ori_type)
seg_pred = seg_pred.squeeze(1) data_sample.set_data(
seg_pred = list(seg_pred) dict(pred_sem_seg=PixelData(**dict(data=seg_pred))))
return seg_pred predictions.append(data_sample)
def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ return predictions
List[np.ndarray]:
"""The interface for forward test.
Args:
imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
Returns:
List[np.ndarray]: A list of segmentation map.
"""
outputs = self.wrapper({self.input_name: imgs})
outputs = self.wrapper.output_to_list(outputs)
outputs = [out.detach().cpu().numpy() for out in outputs]
return outputs
def show_result(self,
img: np.ndarray,
result: list,
win_name: str = '',
palette: Optional[np.ndarray] = None,
show: bool = True,
opacity: float = 0.5,
out_file: str = None):
"""Show predictions of segmentation.
Args:
img: (np.ndarray): Input image to draw predictions.
result (list): A list of predictions.
win_name (str): The name of visualization window. Default is ''.
palette (np.ndarray): The palette of segmentation map.
show (bool): Whether to show plotted image in windows. Defaults to
`True`.
opacity: (float): Opacity of painted segmentation map.
Defaults to `0.5`.
out_file (str): Output image file to save drawn predictions.
Returns:
np.ndarray: Drawn image, only if not `show` or `out_file`.
"""
palette = self.PALETTE if palette is None else palette
return BaseSegmentor.show_result(
self,
img,
result,
palette=palette,
opacity=opacity,
show=show,
win_name=win_name,
out_file=out_file)
@__BACKEND_MODEL.register_module('sdk') @__BACKEND_MODEL.register_module('sdk')
class SDKEnd2EndModel(End2EndModel): class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmseg format.""" """SDK inference class, converts SDK output to mmseg format."""
def forward(self, img: Sequence[torch.Tensor], def __init__(self, *args, **kwargs):
img_metas: Sequence[Sequence[dict]], *args, **kwargs): kwargs['data_preprocessor'] = None
super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'predict'):
"""Run forward inference. """Run forward inference.
Args: Args:
@ -167,42 +134,26 @@ class SDKEnd2EndModel(End2EndModel):
Returns: Returns:
list: A list contains predictions. list: A list contains predictions.
""" """
masks = self.wrapper.invoke(img[0].contiguous().detach().cpu().numpy()) if isinstance(inputs, list):
return masks inputs = inputs[0]
# inputs are c,h,w, sdk requested h,w,c
inputs = inputs.permute(1, 2, 0)
outputs = self.wrapper.invoke(
inputs.contiguous().detach().cpu().numpy())
batch_outputs = torch.from_numpy(outputs).to(torch.int64).to(
self.device)
batch_outputs = batch_outputs.unsqueeze(0).unsqueeze(0)
return self.pack_result(batch_outputs, data_samples)
def get_classes_palette_from_config(model_cfg: Union[str, mmengine.Config]): def build_segmentation_model(
"""Get class name and palette from config. model_files: Sequence[str],
model_cfg: Union[str, Config],
Args: deploy_cfg: Union[str, Config],
model_cfg (str | mmengine.Config): Input model config file or device: str,
Config object. data_preprocessor: Optional[Union[Config,
Returns: BaseDataPreprocessor]] = None,
tuple(Sequence[str], np.ndarray): A list of string specifying names of **kwargs):
different class and the palette of segmentation map.
"""
# load cfg if necessary
model_cfg = load_config(model_cfg)[0]
module_dict = DATASETS.module_dict
data_cfg = model_cfg.data
if 'val' in data_cfg:
module = module_dict[data_cfg.val.type]
elif 'test' in data_cfg:
module = module_dict[data_cfg.test.type]
elif 'train' in data_cfg:
module = module_dict[data_cfg.train.type]
else:
raise RuntimeError(f'No dataset config found in: {model_cfg}')
return module.CLASSES, module.PALETTE
def build_segmentation_model(model_files: Sequence[str],
model_cfg: Union[str, mmengine.Config],
deploy_cfg: Union[str, mmengine.Config],
device: str, **kwargs):
"""Build object segmentation model for different backends. """Build object segmentation model for different backends.
Args: Args:
@ -212,25 +163,25 @@ def build_segmentation_model(model_files: Sequence[str],
deploy_cfg (str | mmengine.Config): Input deployment config file or deploy_cfg (str | mmengine.Config): Input deployment config file or
Config object. Config object.
device (str): Device to input model. device (str): Device to input model.
data_preprocessor (BaseDataPreprocessor | Config): The data
preprocessor of the model.
Returns: Returns:
BaseBackendModel: Segmentor for a configured backend. BaseBackendModel: Segmentor for a configured backend.
""" """
# load cfg if necessary # load cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
backend = get_backend(deploy_cfg) backend = get_backend(deploy_cfg)
model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end') model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')
class_names, palette = get_classes_palette_from_config(model_cfg)
backend_segmentor = __BACKEND_MODEL.build( backend_segmentor = __BACKEND_MODEL.build(
model_type, dict(
type=model_type,
backend=backend, backend=backend,
backend_files=model_files, backend_files=model_files,
device=device, device=device,
class_names=class_names,
palette=palette,
deploy_cfg=deploy_cfg, deploy_cfg=deploy_cfg,
**kwargs) data_preprocessor=data_preprocessor,
**kwargs))
return backend_segmentor return backend_segmentor

View File

@ -1,6 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .aspp_head import aspp_head__forward
from .ema_head import ema_module__forward from .ema_head import ema_module__forward
from .psp_head import ppm__forward
__all__ = ['aspp_head__forward', 'ppm__forward', 'ema_module__forward'] __all__ = ['ema_module__forward']

View File

@ -1,43 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.ops import resize
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.decode_heads.ASPPHead.forward')
def aspp_head__forward(ctx, self, inputs):
"""Rewrite `forward` for default backend.
Support configured dynamic/static shape in resize op.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
inputs (list[Tensor]): List of multi-level img features.
Returns:
torch.Tensor: Output segmentation map.
"""
x = self._transform_inputs(inputs)
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
# get origin input shape as tensor to support onnx dynamic shape
size = x.shape[2:]
if not is_dynamic_flag:
size = [int(val) for val in size]
aspp_outs = [
resize(
self.image_pool(x),
size=size,
mode='bilinear',
align_corners=self.align_corners)
]
aspp_outs.extend(self.aspp_modules(x))
aspp_outs = torch.cat(aspp_outs, dim=1)
output = self.bottleneck(aspp_outs)
output = self.cls_seg(output)
return output

View File

@ -1,52 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmseg.ops import resize
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import IR, get_root_logger, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.decode_heads.psp_head.PPM.forward', ir=IR.ONNX)
def ppm__forward(ctx, self, x):
"""Rewrite `forward` for default backend.
Support configured dynamic/static shape in resize op.
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
x (Tensor): The transformed input feature.
Returns:
List[torch.Tensor]: Up-sampled segmentation maps of different
scales.
"""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
# get origin input shape as tensor to support onnx dynamic shape
size = x.shape[2:]
if not is_dynamic_flag:
size = [int(val) for val in size]
ppm_outs = []
for ppm in self:
if isinstance(ppm[0], nn.AdaptiveAvgPool2d) and \
ppm[0].output_size != 1:
if is_dynamic_flag:
logger = get_root_logger()
logger.warning('`AdaptiveAvgPool2d` would be '
'replaced to `AvgPool2d` explicitly')
# replace AdaptiveAvgPool2d with AvgPool2d explicitly
output_size = 2 * [ppm[0].output_size]
k = [int(size[i] / output_size[i]) for i in range(0, len(size))]
ppm[0] = nn.AvgPool2d(k, stride=k, padding=0, ceil_mode=False)
ppm_out = ppm(x)
upsampled_ppm_out = resize(
ppm_out,
size=size,
mode='bilinear',
align_corners=self.align_corners)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs

View File

@ -1,5 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base import base_segmentor__forward from .base import base_segmentor__forward
from .encoder_decoder import encoder_decoder__simple_test from .cascade_encoder_decoder import cascade_encoder_decoder__predict
from .encoder_decoder import encoder_decoder__predict
__all__ = ['base_segmentor__forward', 'encoder_decoder__simple_test'] __all__ = [
'base_segmentor__forward', 'encoder_decoder__predict',
'cascade_encoder_decoder__predict'
]

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch from mmseg.structures import SegDataSample
from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape from mmdeploy.utils import is_dynamic_shape
@ -7,7 +7,12 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward') func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs): def base_segmentor__forward(ctx,
self,
inputs,
data_samples=None,
mode='predict',
**kwargs):
"""Rewrite `forward` for default backend. """Rewrite `forward` for default backend.
Support configured dynamic/static shape for model input. Support configured dynamic/static shape for model input.
@ -15,27 +20,23 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
Args: Args:
ctx (ContextCaller): The context with additional information. ctx (ContextCaller): The context with additional information.
self: The instance of the original class. self: The instance of the original class.
img (Tensor | List[Tensor]): Input image tensor(s). inputs (Tensor | List[Tensor]): Input image tensor(s).
img_metas (List[dict]): List of dicts containing image's meta data_samples (List[dict]): List of dicts containing image's meta
information such as `img_shape`. information such as `img_shape`.
Returns: Returns:
torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
""" """
if img_metas is None: if data_samples is None:
img_metas = [{}] data_samples = [SegDataSample()]
else:
assert len(img_metas) == 1, 'do not support aug_test'
img_metas = img_metas[0]
if isinstance(img, list):
img = img[0]
assert isinstance(img, torch.Tensor)
deploy_cfg = ctx.cfg deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg) is_dynamic_flag = is_dynamic_shape(deploy_cfg)
# get origin input shape as tensor to support onnx dynamic shape # get origin input shape as tensor to support onnx dynamic shape
img_shape = img.shape[2:] img_shape = inputs.shape[2:]
if not is_dynamic_flag: if not is_dynamic_flag:
img_shape = [int(val) for val in img_shape] img_shape = [int(val) for val in img_shape]
img_metas[0]['img_shape'] = img_shape for data_sample in data_samples:
return self.simple_test(img, img_metas, **kwargs) data_sample.set_field(
name='img_shape', value=img_shape, field_type='metainfo')
return self.predict(inputs, data_samples)

View File

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.CascadeEncoderDecoder.predict')
def cascade_encoder_decoder__predict(ctx, self, inputs, data_samples,
**kwargs):
"""Rewrite `predict` for default backend.
1. only support mode=`whole` inference
2. skip calling self.postprocess_result
Args:
ctx (ContextCaller): The context with additional information.
self: The instance of the original class.
inputs (Tensor): Inputs with shape (N, C, H, W).
data_samples (SampleList): The seg data samples.
Returns:
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
"""
batch_img_metas = []
for data_sample in data_samples:
batch_img_metas.append(data_sample.metainfo)
x = self.extract_feat(inputs)
out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages - 1):
out = self.decode_head[i].forward(x, out)
seg_logit = self.decode_head[-1].predict(x, out, batch_img_metas,
self.test_cfg)
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
return seg_pred

View File

@ -1,28 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter( @FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.EncoderDecoder.simple_test') func_name='mmseg.models.segmentors.EncoderDecoder.predict')
def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs): def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs):
"""Rewrite `simple_test` for default backend. """Rewrite `predict` for default backend.
Support configured dynamic/static shape for model input and return 1. only support mode=`whole` inference
segmentation map as Tensor instead of numpy array. 2. skip calling self.postprocess_result
Args: Args:
ctx (ContextCaller): The context with additional information. ctx (ContextCaller): The context with additional information.
self: The instance of the original class. self: The instance of the original class.
img (Tensor | List[Tensor]): Input image tensor(s). inputs (Tensor): Inputs with shape (N, C, H, W).
img_meta (dict): Dict containing image's meta information data_samples (SampleList): The seg data samples.
such as `img_shape`.
Returns: Returns:
torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
""" """
seg_logit = self.encode_decode(img, img_meta) batch_img_metas = []
seg_logit = F.softmax(seg_logit, dim=1) for data_sample in data_samples:
batch_img_metas.append(data_sample.metainfo)
x = self.extract_feat(inputs)
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
seg_pred = seg_logit.argmax(dim=1, keepdim=True) seg_pred = seg_logit.argmax(dim=1, keepdim=True)
return seg_pred return seg_pred

View File

@ -28,7 +28,7 @@ def up_conv_block__forward(ctx, self, skip, x):
# only valid when self.upsample is from build_upsample_layer # only valid when self.upsample is from build_upsample_layer
if is_dynamic_shape(ctx.cfg) and not isinstance(self.upsample, ConvModule): if is_dynamic_shape(ctx.cfg) and not isinstance(self.upsample, ConvModule):
# upsample with `size` instead of `scale_factor` # upsample with `size` instead of `scale_factor`
from mmseg.ops import Upsample from mmseg.models.utils import Upsample
for c in self.upsample.interp_upsample: for c in self.upsample.interp_upsample:
if isinstance(c, Upsample): if isinstance(c, Upsample):
c.size = skip.shape[-2:] c.size = skip.shape[-2:]

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .adaptive_pool import (adaptive_avg_pool2d__default,
adaptive_avg_pool2d__ncnn)
from .atan2 import atan2__default from .atan2 import atan2__default
from .chunk import chunk__ncnn, chunk__torchscript from .chunk import chunk__ncnn, chunk__torchscript
from .expand import expand__ncnn from .expand import expand__ncnn
@ -18,7 +20,8 @@ __all__ = [
'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn', 'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', 'triu__default', 'atan2__default', 'adaptive_avg_pool2d__default',
'chunk__torchscript', 'masked_fill__onnxruntime', 'normalize__ncnn', 'expand__ncnn', 'chunk__torchscript',
'tensor__setitem__default' 'masked_fill__onnxruntime', 'tensor__setitem__default',
'adaptive_avg_pool2d__ncnn'
] ]

View File

@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.adaptive_avg_pool2d')
def adaptive_avg_pool2d__default(ctx, input, output_size):
"""Rewrite `adaptive_avg_pool2d` for default backend."""
output_size = _pair(output_size)
if int(output_size[0]) == int(output_size[1]) == 1:
out = ctx.origin_func(input, output_size)
else:
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
if is_dynamic_flag:
logger = get_root_logger()
logger.warning('`adaptive_avg_pool2d` would be '
'replaced to `avg_pool2d` explicitly')
size = input.shape[2:]
k = [int(size[i] / output_size[i]) for i in range(0, len(size))]
out = F.avg_pool2d(
input,
kernel_size=k,
stride=k,
padding=0,
ceil_mode=False,
count_include_pad=False)
return out
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.adaptive_avg_pool2d',
backend=Backend.NCNN.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.adaptive_avg_pool2d',
backend=Backend.TORCHSCRIPT.value)
def adaptive_avg_pool2d__ncnn(ctx, input, output_size):
"""Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend."""
return ctx.origin_func(input, output_size)

View File

@ -1,8 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .adaptive_avg_pool import (adaptive_avg_pool1d__default, from .adaptive_pool import adaptive_avg_pool2d__ncnn
adaptive_avg_pool2d__default,
adaptive_avg_pool2d__ncnn,
adaptive_avg_pool3d__default)
from .gelu import gelu__ncnn from .gelu import gelu__ncnn
from .grid_sampler import grid_sampler__default from .grid_sampler import grid_sampler__default
from .hardsigmoid import hardsigmoid__default from .hardsigmoid import hardsigmoid__default
@ -15,10 +12,8 @@ from .roll import roll_default
from .squeeze import squeeze__default from .squeeze import squeeze__default
__all__ = [ __all__ = [
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default', 'grid_sampler__default', 'hardsigmoid__default', 'instance_norm__tensorrt',
'adaptive_avg_pool3d__default', 'grid_sampler__default', 'generic_rnn__ncnn', 'squeeze__default', 'adaptive_avg_pool2d__ncnn',
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn', 'gelu__ncnn', 'layer_norm__ncnn', 'linear__ncnn',
'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn', '_prepare_onnx_paddings__tensorrt', 'roll_default'
'layer_norm__ncnn', 'linear__ncnn', '_prepare_onnx_paddings__tensorrt',
'roll_default'
] ]

View File

@ -1,90 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from torch.nn.modules.utils import _pair, _single, _triple
from torch.onnx.symbolic_helper import parse_args
from mmdeploy.core import SYMBOLIC_REWRITER
def _adaptive_pool(name, type, tuple_fn, fn=None):
"""Generic adaptive pooling."""
@parse_args('v', 'is')
def symbolic_fn(g, input, output_size):
if output_size == [1] * len(output_size) and type == 'AveragePool':
return g.op('GlobalAveragePool', input)
if not input.isCompleteTensor():
if output_size == [1] * len(output_size):
return g.op('GlobalMaxPool', input), None
raise NotImplementedError(
'[Adaptive pool]:input size not accessible')
dim = input.type().sizes()[2:]
if output_size == [1] * len(output_size) and type == 'MaxPool':
return g.op('GlobalMaxPool', input), None
# compute stride = floor(input_size / output_size)
s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
# compute kernel_size = input_size - (output_size - 1) * stride
k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]
# call max_poolxd_with_indices to get indices in the output
if type == 'MaxPool':
return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
False)
output = g.op(
type,
input,
kernel_shape_i=tuple_fn(k),
strides_i=tuple_fn(s),
ceil_mode_i=False)
return output
return symbolic_fn
adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
_single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
_pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
_triple)
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
def adaptive_avg_pool1d__default(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool1d`.
Align symbolic of adaptive_pool between different torch version.
"""
return adaptive_avg_pool1d(*args)
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
def adaptive_avg_pool2d__default(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool2d`.
Align symbolic of adaptive_pool between different torch version.
"""
return adaptive_avg_pool2d(*args)
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
def adaptive_avg_pool3d__default(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool3d`.
Align symbolic of adaptive_pool between different torch version.
"""
return adaptive_avg_pool3d(*args)
@SYMBOLIC_REWRITER.register_symbolic(
'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn')
def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size):
"""Register ncnn symbolic function for `adaptive_avg_pool2d`.
Align symbolic of adaptive_avg_pool2d in ncnn.
"""
return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size)

View File

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn')
def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size):
"""Register ncnn symbolic function for `adaptive_avg_pool2d`.
Align symbolic of adaptive_avg_pool2d in ncnn.
"""
return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size)

View File

@ -92,7 +92,8 @@ class TimeCounter:
warmup: int = 1, warmup: int = 1,
log_interval: int = 1, log_interval: int = 1,
with_sync: bool = False, with_sync: bool = False,
file: Optional[str] = None): file: Optional[str] = None,
logger=None):
"""Activate the time counter. """Activate the time counter.
Args: Args:
@ -106,6 +107,7 @@ class TimeCounter:
is `None`. is `None`.
""" """
assert warmup >= 1 assert warmup >= 1
if logger is None:
logger = get_logger('test', log_file=file) logger = get_logger('test', log_file=file)
cls.logger = logger cls.logger = logger
if func_name is not None: if func_name is not None:

View File

@ -2,12 +2,9 @@ globals:
codebase_dir: ../mmsegmentation codebase_dir: ../mmsegmentation
checkpoint_force_download: False checkpoint_force_download: False
images: images:
img_leftImg8bit: &img_leftImg8bit ../mmsegmentation/tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png img_leftImg8bit: &img_leftImg8bit ../mmsegmentation/tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png
img_loveda_0: &img_loveda_0 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/0.png img_loveda_0: &img_loveda_0 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/0.png
img_loveda_1: &img_loveda_1 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/1.png
img_loveda_2: &img_loveda_2 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/2.png
img_potsdam: &img_potsdam ../mmsegmentation/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png
img_vaihingen: &img_vaihingen ../mmsegmentation/tests/data/pseudo_vaihingen_dataset/img_dir/area1_0_0_512_512.png
metric_info: &metric_info metric_info: &metric_info
mIoU: # named after metafile.Results.Metrics mIoU: # named after metafile.Results.Metrics
eval_name: mIoU # test.py --metrics args eval_name: mIoU # test.py --metrics args
@ -25,13 +22,9 @@ globals:
onnxruntime: onnxruntime:
pipeline_ort_static_fp32: &pipeline_ort_static_fp32 pipeline_ort_static_fp32: &pipeline_ort_static_fp32
convert_image: *convert_image convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic
deploy_config: configs/mmseg/segmentation_onnxruntime_static-1024x2048.py deploy_config: configs/mmseg/segmentation_onnxruntime_static-1024x2048.py
pipeline_ort_static_fp32_512x512: &pipeline_ort_static_fp32_512x512 pipeline_ort_static_fp32_512x512: &pipeline_ort_static_fp32_512x512
convert_image: *convert_image
backend_test: False
deploy_config: configs/mmseg/segmentation_onnxruntime_static-512x512.py deploy_config: configs/mmseg/segmentation_onnxruntime_static-512x512.py
pipeline_ort_dynamic_fp32: &pipeline_ort_dynamic_fp32 pipeline_ort_dynamic_fp32: &pipeline_ort_dynamic_fp32
@ -129,13 +122,11 @@ models:
- name: FCN - name: FCN
metafile: configs/fcn/fcn.yml metafile: configs/fcn/fcn.yml
model_configs: model_configs:
- configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py - configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
# - *pipeline_trt_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
# - *pipeline_trt_dynamic_int8
- *pipeline_ncnn_static_fp32 - *pipeline_ncnn_static_fp32
- *pipeline_pplnn_dynamic_fp32 - *pipeline_pplnn_dynamic_fp32
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_dynamic_fp32
@ -143,7 +134,7 @@ models:
- name: PSPNet - name: PSPNet
metafile: configs/pspnet/pspnet.yml metafile: configs/pspnet/pspnet.yml
model_configs: model_configs:
- configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py - configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- *pipeline_ort_static_fp32 - *pipeline_ort_static_fp32
@ -155,7 +146,7 @@ models:
- name: deeplabv3 - name: deeplabv3
metafile: configs/deeplabv3/deeplabv3.yml metafile: configs/deeplabv3/deeplabv3.yml
model_configs: model_configs:
- configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py - configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
@ -167,7 +158,7 @@ models:
- name: deeplabv3+ - name: deeplabv3+
metafile: configs/deeplabv3plus/deeplabv3plus.yml metafile: configs/deeplabv3plus/deeplabv3plus.yml
model_configs: model_configs:
- configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py - configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
@ -179,7 +170,7 @@ models:
- name: Fast-SCNN - name: Fast-SCNN
metafile: configs/fastscnn/fastscnn.yml metafile: configs/fastscnn/fastscnn.yml
model_configs: model_configs:
- configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py - configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- *pipeline_ort_static_fp32 - *pipeline_ort_static_fp32
@ -190,7 +181,7 @@ models:
- name: UNet - name: UNet
metafile: configs/unet/unet.yml metafile: configs/unet/unet.yml
model_configs: model_configs:
- configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py - configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
@ -202,7 +193,7 @@ models:
- name: ANN - name: ANN
metafile: configs/ann/ann.yml metafile: configs/ann/ann.yml
model_configs: model_configs:
- configs/ann/ann_r50-d8_512x1024_40k_cityscapes.py - configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_static_fp32 - *pipeline_ort_static_fp32
- *pipeline_trt_static_fp16 - *pipeline_trt_static_fp16
@ -210,7 +201,7 @@ models:
- name: APCNet - name: APCNet
metafile: configs/apcnet/apcnet.yml metafile: configs/apcnet/apcnet.yml
model_configs: model_configs:
- configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py - configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -219,7 +210,7 @@ models:
- name: BiSeNetV1 - name: BiSeNetV1
metafile: configs/bisenetv1/bisenetv1.yml metafile: configs/bisenetv1/bisenetv1.yml
model_configs: model_configs:
- configs/bisenetv1/bisenetv1_r18-d32_4x4_1024x1024_160k_cityscapes.py - configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -229,7 +220,7 @@ models:
- name: BiSeNetV2 - name: BiSeNetV2
metafile: configs/bisenetv2/bisenetv2.yml metafile: configs/bisenetv2/bisenetv2.yml
model_configs: model_configs:
- configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py - configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -239,7 +230,7 @@ models:
- name: CGNet - name: CGNet
metafile: configs/cgnet/cgnet.yml metafile: configs/cgnet/cgnet.yml
model_configs: model_configs:
- configs/cgnet/cgnet_512x1024_60k_cityscapes.py - configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -249,7 +240,7 @@ models:
- name: EMANet - name: EMANet
metafile: configs/emanet/emanet.yml metafile: configs/emanet/emanet.yml
model_configs: model_configs:
- configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py - configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -258,7 +249,7 @@ models:
- name: EncNet - name: EncNet
metafile: configs/encnet/encnet.yml metafile: configs/encnet/encnet.yml
model_configs: model_configs:
- configs/encnet/encnet_r50-d8_512x1024_40k_cityscapes.py - configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -267,7 +258,7 @@ models:
- name: ERFNet - name: ERFNet
metafile: configs/erfnet/erfnet.yml metafile: configs/erfnet/erfnet.yml
model_configs: model_configs:
- configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py - configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -277,7 +268,7 @@ models:
- name: FastFCN - name: FastFCN
metafile: configs/fastfcn/fastfcn.yml metafile: configs/fastfcn/fastfcn.yml
model_configs: model_configs:
- configs/fastfcn/fastfcn_r50-d32_jpu_aspp_512x1024_80k_cityscapes.py - configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -287,7 +278,7 @@ models:
- name: GCNet - name: GCNet
metafile: configs/gcnet/gcnet.yml metafile: configs/gcnet/gcnet.yml
model_configs: model_configs:
- configs/gcnet/gcnet_r50-d8_512x1024_40k_cityscapes.py - configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -295,7 +286,7 @@ models:
- name: ICNet - name: ICNet
metafile: configs/icnet/icnet.yml metafile: configs/icnet/icnet.yml
model_configs: model_configs:
- configs/icnet/icnet_r18-d8_832x832_80k_cityscapes.py - configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py
pipelines: pipelines:
- *pipeline_ort_static_fp32 - *pipeline_ort_static_fp32
- *pipeline_trt_static_fp16 - *pipeline_trt_static_fp16
@ -304,26 +295,26 @@ models:
- name: ISANet - name: ISANet
metafile: configs/isanet/isanet.yml metafile: configs/isanet/isanet.yml
model_configs: model_configs:
- configs/isanet/isanet_r50-d8_512x1024_40k_cityscapes.py - configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_static_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_static_fp16
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_static_fp32
- name: OCRNet - name: OCRNet
metafile: configs/ocrnet/ocrnet.yml metafile: configs/ocrnet/ocrnet.yml
model_configs: model_configs:
- configs/ocrnet/ocrnet_hr18s_512x1024_40k_cityscapes.py - configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp32
- *pipeline_ncnn_static_fp32 - *pipeline_ncnn_static_fp32
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_dynamic_fp32
- name: PointRend - name: PointRend
metafile: configs/point_rend/point_rend.yml metafile: configs/point_rend/point_rend.yml
model_configs: model_configs:
- configs/point_rend/pointrend_r50_512x1024_80k_cityscapes.py - configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -332,18 +323,18 @@ models:
- name: Semantic FPN - name: Semantic FPN
metafile: configs/sem_fpn/sem_fpn.yml metafile: configs/sem_fpn/sem_fpn.yml
model_configs: model_configs:
- configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py - configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp32
- *pipeline_ncnn_static_fp32 - *pipeline_ncnn_static_fp32
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_dynamic_fp32
- name: STDC - name: STDC
metafile: configs/stdc/stdc.yml metafile: configs/stdc/stdc.yml
model_configs: model_configs:
- configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py - configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py
- configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py - configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_dynamic_fp32 - *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp16 - *pipeline_trt_dynamic_fp16
@ -353,15 +344,15 @@ models:
- name: UPerNet - name: UPerNet
metafile: configs/upernet/upernet.yml metafile: configs/upernet/upernet.yml
model_configs: model_configs:
- configs/upernet/upernet_r50_512x1024_40k_cityscapes.py - configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_static_fp32 - *pipeline_ort_static_fp32
- *pipeline_trt_static_fp16 - *pipeline_trt_static_fp16
- name: Segmenter - name: Segmenter
metafile: configs/segmenter/segmenter.yml metafile: configs/segmenter/segmenter.yml
model_configs: model_configs:
- configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py - configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
- configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py - configs/segmenter/segmenter_vit-s_mask_8xb1-160k_ade20k-512x512.py
pipelines: pipelines:
- *pipeline_ort_static_fp32_512x512 - *pipeline_ort_static_fp32_512x512
- *pipeline_trt_static_fp32_512x512 - *pipeline_trt_static_fp32_512x512

View File

@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import (generate_datasample, generate_mmseg_deploy_config,
generate_mmseg_task_processor)
__all__ = [
'generate_datasample', 'generate_mmseg_deploy_config',
'generate_mmseg_task_processor'
]

View File

@ -2,44 +2,44 @@
# dataset settings # dataset settings
dataset_type = 'CityscapesDataset' dataset_type = 'CityscapesDataset'
data_root = '.' data_root = '.'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (128, 128) crop_size = (128, 128)
test_pipeline = [ test_pipeline = [
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict( dict(type='Resize', scale=crop_size, keep_ratio=False),
type='MultiScaleFlipAug', # add loading annotation after ``Resize`` because ground truth
img_scale=(128, 128), # does not need to do resize data transform
flip=False, dict(type='LoadAnnotations', reduce_zero_label=True),
transforms=[ dict(type='PackSegInputs')
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
] ]
data = dict( val_dataloader = dict(
samples_per_gpu=1, batch_size=1,
workers_per_gpu=1, num_workers=1,
val=dict( persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
img_dir='', data_prefix=dict(img_path='', seg_map_path=''),
ann_dir='',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='',
ann_dir='',
pipeline=test_pipeline)) pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
# model settings # model settings
norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict( model = dict(
type='EncoderDecoder', type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict( backbone=dict(
type='FastSCNN', type='FastSCNN',
downsample_dw_channels=(32, 48), downsample_dw_channels=(32, 48),
@ -64,6 +64,33 @@ model = dict(
align_corners=False, align_corners=False,
loss_decode=dict( loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)), type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)),
# model training and testing settings # model training and testing settings
train_cfg=dict(), train_cfg=dict(),
test_cfg=dict(mode='whole')) test_cfg=dict(mode='whole'))
# from default_runtime
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
log_level = 'INFO'
load_from = None
resume = False
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# from schedules
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
sampler_seed=dict(type='DistSamplerSeedHook'),
)

View File

@ -1,12 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import mmcv import mmengine
import numpy as np
import pytest import pytest
import torch import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmdeploy.codebase import import_codebase from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, Task from mmdeploy.utils import Backend, Codebase, Task
@ -15,225 +10,52 @@ from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
import_codebase(Codebase.MMSEG) import_codebase(Codebase.MMSEG)
from .utils import generate_datasample # noqa: E402
@BACKBONES.register_module() from .utils import generate_mmseg_deploy_config # noqa: E402
class ExampleBackbone(nn.Module): from .utils import generate_mmseg_task_processor # noqa: E402
def __init__(self):
super(ExampleBackbone, self).__init__()
self.conv = nn.Conv2d(3, 3, 3)
def init_weights(self, pretrained=None):
pass
def forward(self, x):
return [self.conv(x)]
@HEADS.register_module() @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
class ExampleDecodeHead(BaseDecodeHead): def test_encoderdecoder_predict(backend):
def __init__(self):
super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19)
def forward(self, inputs):
return self.cls_seg(inputs[0])
def get_model(type='EncoderDecoder',
backbone='ExampleBackbone',
decode_head='ExampleDecodeHead'):
from mmseg.models import build_segmentor
cfg = ConfigDict(
type=type,
backbone=dict(type=backbone),
decode_head=dict(type=decode_head),
train_cfg=None,
test_cfg=dict(mode='whole'))
segmentor = build_segmentor(cfg)
return segmentor
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
segs = rng.randint(
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
img_metas = [{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False,
'flip_direction': 'horizontal'
} for _ in range(N)]
mm_inputs = {
'imgs': torch.FloatTensor(imgs),
'img_metas': img_metas,
'gt_semantic_seg': torch.LongTensor(segs)
}
return mm_inputs
@pytest.mark.parametrize('backend',
[Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
def test_encoderdecoder_simple_test(backend):
check_backend(backend) check_backend(backend)
segmentor = get_model() deploy_cfg = generate_mmseg_deploy_config(backend.value)
segmentor.cpu().eval() task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
segmentor = task_processor.build_pytorch_model()
deploy_cfg = mmcv.Config( size = 256
dict( inputs = torch.randn(1, 3, size, size)
backend_config=dict(type=backend.value), data_samples = [generate_datasample(size, size)]
onnx_config=dict(output_names=['result'], input_shape=None), wrapped_model = WrapModel(segmentor, 'predict', data_samples=data_samples)
codebase_config=dict(type='mmseg', task='Segmentation'))) model_outputs = wrapped_model(inputs)[0].pred_sem_seg.data
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
else:
num_classes = segmentor.decode_head.num_classes
mm_inputs = _demo_mm_inputs(
input_shape=(1, 3, 32, 32), num_classes=num_classes)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
model_inputs = {'img': imgs, 'img_meta': img_metas}
model_outputs = get_model_outputs(segmentor, 'simple_test', model_inputs)
img_meta = {
'img_shape':
(img_metas[0]['img_shape'][0], img_metas[0]['img_shape'][1])
}
wrapped_model = WrapModel(segmentor, 'simple_test', img_meta=img_meta)
rewrite_inputs = { rewrite_inputs = {
'img': imgs, 'inputs': inputs,
} }
rewrite_outputs, is_backend_output = get_rewrite_outputs( rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model, wrapped_model=wrapped_model,
model_inputs=rewrite_inputs, model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg) deploy_cfg=deploy_cfg)
assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
model_outputs = torch.tensor(model_outputs[0])
rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(rewrite_outputs, model_outputs)
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO]) @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_basesegmentor_forward(backend): def test_basesegmentor_forward(backend):
check_backend(backend) check_backend(backend)
segmentor = get_model() deploy_cfg = generate_mmseg_deploy_config(backend.value)
segmentor.cpu().eval() task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
segmentor = task_processor.build_pytorch_model()
deploy_cfg = mmcv.Config( size = 256
dict( inputs = torch.randn(1, 3, size, size)
backend_config=dict(type=backend.value), data_samples = [generate_datasample(size, size)]
onnx_config=dict(output_names=['result'], input_shape=None), wrapped_model = WrapModel(
codebase_config=dict(type='mmseg', task='Segmentation'))) segmentor, 'forward', data_samples=data_samples, mode='predict')
model_outputs = wrapped_model(inputs)[0].pred_sem_seg.data
if isinstance(segmentor.decode_head, nn.ModuleList): rewrite_inputs = {
num_classes = segmentor.decode_head[-1].num_classes 'inputs': inputs,
else:
num_classes = segmentor.decode_head.num_classes
mm_inputs = _demo_mm_inputs(num_classes=num_classes)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
model_inputs = {
'img': [imgs],
'img_metas': [img_metas],
'return_loss': False
} }
model_outputs = get_model_outputs(segmentor, 'forward', model_inputs) rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model = WrapModel(segmentor, 'forward')
rewrite_inputs = {'img': imgs}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model, wrapped_model=wrapped_model,
model_inputs=rewrite_inputs, model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg) deploy_cfg=deploy_cfg)
assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
model_outputs = torch.tensor(model_outputs[0])
rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(rewrite_outputs, model_outputs)
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_aspphead_forward(backend):
check_backend(backend)
from mmseg.models.decode_heads import ASPPHead
head = ASPPHead(
in_channels=32, channels=16, num_classes=19,
dilations=(1, 12, 24)).eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend.value),
onnx_config=dict(
output_names=['result'], input_shape=(1, 32, 45, 45)),
codebase_config=dict(type='mmseg', task='Segmentation')))
inputs = [torch.randn(1, 32, 45, 45)]
model_inputs = {'inputs': inputs}
with torch.no_grad():
model_outputs = get_model_outputs(head, 'forward', model_inputs)
wrapped_model = WrapModel(head, 'forward')
rewrite_inputs = {'inputs': inputs}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(
rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend',
[Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
def test_psphead_forward(backend):
check_backend(backend)
from mmseg.models.decode_heads import PSPHead
head = PSPHead(in_channels=32, channels=16, num_classes=19).eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend.value),
onnx_config=dict(output_names=['result'], input_shape=None),
codebase_config=dict(type='mmseg', task='Segmentation')))
inputs = [torch.randn(1, 32, 45, 45)]
model_inputs = {'inputs': inputs}
with torch.no_grad():
model_outputs = get_model_outputs(head, 'forward', model_inputs)
wrapped_model = WrapModel(head, 'forward')
rewrite_inputs = {'inputs': inputs}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
rewrite_outputs = rewrite_outputs[0]
rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
model_outputs.shape)
assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME]) @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@ -242,7 +64,7 @@ def test_emamodule_forward(backend):
from mmseg.models.decode_heads.ema_head import EMAModule from mmseg.models.decode_heads.ema_head import EMAModule
head = EMAModule(8, 2, 2, 1.0).eval() head = EMAModule(8, 2, 2, 1.0).eval()
deploy_cfg = mmcv.Config( deploy_cfg = mmengine.Config(
dict( dict(
backend_config=dict(type=backend.value), backend_config=dict(type=backend.value),
onnx_config=dict( onnx_config=dict(
@ -290,7 +112,7 @@ def test_upconvblock_forward(backend, is_dynamic_shape):
3: 'w' 3: 'w'
}, },
} if is_dynamic_shape else None } if is_dynamic_shape else None
deploy_cfg = mmcv.Config( deploy_cfg = mmengine.Config(
dict( dict(
backend_config=dict(type=backend.value), backend_config=dict(type=backend.value),
onnx_config=dict( onnx_config=dict(

View File

@ -1,42 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any from typing import Any
import mmcv import mmcv
import numpy as np
import pytest import pytest
import torch import torch
from torch.utils.data import DataLoader
import mmdeploy.backend.onnxruntime as ort_apis import mmdeploy.backend.onnxruntime as ort_apis
from mmdeploy.apis import build_task_processor from mmdeploy.apis import build_task_processor
from mmdeploy.codebase import import_codebase from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase, load_config from mmdeploy.utils import Codebase, load_config
from mmdeploy.utils.test import DummyModel, SwitchBackendWrapper from mmdeploy.utils.test import SwitchBackendWrapper
import_codebase(Codebase.MMSEG) import_codebase(Codebase.MMSEG)
from .utils import generate_datasample # noqa: E402
from .utils import generate_mmseg_deploy_config # noqa: E402
model_cfg_path = 'tests/test_codebase/test_mmseg/data/model.py' model_cfg_path = 'tests/test_codebase/test_mmseg/data/model.py'
model_cfg = load_config(model_cfg_path)[0] model_cfg = load_config(model_cfg_path)[0]
deploy_cfg = mmcv.Config( deploy_cfg = generate_mmseg_deploy_config()
dict(
backend_config=dict(type='onnxruntime'),
codebase_config=dict(type='mmseg', task='Segmentation'),
onnx_config=dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
input_shape=None,
input_names=['input'],
output_names=['output'])))
onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu') task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
img_shape = (32, 32) img_shape = (32, 32)
img = np.random.rand(*img_shape, 3) tiger_img_path = 'tests/data/tiger.jpeg'
img = mmcv.imread(tiger_img_path)
img = mmcv.imresize(img, img_shape)
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) @pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
@ -88,29 +78,31 @@ def test_build_backend_model(backend_model):
def test_create_input(): def test_create_input():
inputs = task_processor.create_input(img, input_shape=img_shape) img_path = 'tests/data/tiger.jpeg'
data_preprocessor = task_processor.build_data_preprocessor()
inputs = task_processor.create_input(
img_path, input_shape=img_shape, data_preprocessor=data_preprocessor)
assert isinstance(inputs, tuple) and len(inputs) == 2 assert isinstance(inputs, tuple) and len(inputs) == 2
def test_run_inference(backend_model): def test_build_data_preprocessor():
input_dict, _ = task_processor.create_input(img, input_shape=img_shape) from mmseg.models import SegDataPreProcessor
results = task_processor.run_inference(backend_model, input_dict) data_preprocessor = task_processor.build_data_preprocessor()
assert results is not None assert isinstance(data_preprocessor, SegDataPreProcessor)
def test_visualize(backend_model): def test_get_visualizer():
input_dict, _ = task_processor.create_input(img, input_shape=img_shape) from mmseg.visualization import SegLocalVisualizer
results = task_processor.run_inference(backend_model, input_dict) tmp_dir = TemporaryDirectory().name
with TemporaryDirectory() as dir: visualizer = task_processor.get_visualizer('ort', tmp_dir)
filename = dir + 'tmp.jpg' assert isinstance(visualizer, SegLocalVisualizer)
task_processor.visualize(backend_model, img, results[0], filename, '')
assert os.path.exists(filename)
def test_get_tensort_from_input(): def test_get_tensort_from_input():
input_data = {'img': [torch.ones(3, 4, 5)]} data = torch.rand(3, 4, 5)
input_data = {'inputs': data}
inputs = task_processor.get_tensor_from_input(input_data) inputs = task_processor.get_tensor_from_input(input_data)
assert torch.equal(inputs, torch.ones(3, 4, 5)) assert torch.equal(inputs, data)
def test_get_partition_cfg(): def test_get_partition_cfg():
@ -122,24 +114,39 @@ def test_get_partition_cfg():
def test_build_dataset_and_dataloader(): def test_build_dataset_and_dataloader():
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
val_dataloader = model_cfg['val_dataloader']
dataset = task_processor.build_dataset( dataset = task_processor.build_dataset(
dataset_cfg=model_cfg, dataset_type='test') dataset_cfg=val_dataloader['dataset'])
assert isinstance(dataset, Dataset), 'Failed to build dataset' assert isinstance(dataset, Dataset), 'Failed to build dataset'
dataloader = task_processor.build_dataloader(dataset, 1, 1) dataloader = task_processor.build_dataloader(val_dataloader)
assert isinstance(dataloader, DataLoader), 'Failed to build dataloader' assert isinstance(dataloader, DataLoader), 'Failed to build dataloader'
def test_single_gpu_test_and_evaluate(): def test_build_test_runner(backend_model):
from mmcv.parallel import MMDataParallel from mmdeploy.codebase.base.runner import DeployTestRunner
temp_dir = TemporaryDirectory().name
runner = task_processor.build_test_runner(backend_model, temp_dir)
assert isinstance(runner, DeployTestRunner)
# Prepare dataloader
dataloader = DataLoader([])
# Prepare dummy model def test_visualize():
model = DummyModel(outputs=[torch.rand([1, 1, *img_shape])]) h, w = img.shape[:2]
model = MMDataParallel(model, device_ids=[0]) datasample = generate_datasample(h, w)
assert model is not None output_file = NamedTemporaryFile(suffix='.jpg').name
# Run test task_processor.visualize(
outputs = task_processor.single_gpu_test(model, dataloader) img, datasample, output_file, show_result=False, window_name='test')
assert outputs is not None
task_processor.evaluate_outputs(model_cfg, outputs, [])
def test_get_preprocess():
process = task_processor.get_preprocess()
assert process is not None
def test_get_postprocess():
process = task_processor.get_postprocess()
assert isinstance(process, dict)
def test_get_model_name():
name = task_processor.get_model_name()
assert isinstance(name, str)

View File

@ -1,10 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import mmengine
from tempfile import NamedTemporaryFile
import mmcv
import numpy as np
import pytest
import torch import torch
import mmdeploy.backend.onnxruntime as ort_apis import mmdeploy.backend.onnxruntime as ort_apis
@ -14,6 +9,9 @@ from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
import_codebase(Codebase.MMSEG) import_codebase(Codebase.MMSEG)
from .utils import generate_datasample # noqa: E402
from .utils import generate_mmseg_deploy_config # noqa: E402
NUM_CLASS = 19 NUM_CLASS = 19
IMAGE_SIZE = 32 IMAGE_SIZE = 32
@ -30,101 +28,34 @@ class TestEnd2EndModel:
# simplify backend inference # simplify backend inference
cls.wrapper = SwitchBackendWrapper(ORTWrapper) cls.wrapper = SwitchBackendWrapper(ORTWrapper)
cls.outputs = { cls.outputs = {
'outputs': torch.rand(1, 1, IMAGE_SIZE, IMAGE_SIZE), 'output': torch.rand(1, 1, IMAGE_SIZE, IMAGE_SIZE),
} }
cls.wrapper.set(outputs=cls.outputs) cls.wrapper.set(outputs=cls.outputs)
deploy_cfg = mmcv.Config( deploy_cfg = generate_mmseg_deploy_config()
{'onnx_config': {
'output_names': ['outputs']
}})
from mmdeploy.codebase.mmseg.deploy.segmentation_model import \ from mmdeploy.codebase.mmseg.deploy.segmentation_model import \
End2EndModel End2EndModel
class_names = ['' for i in range(NUM_CLASS)]
palette = np.random.randint(0, 255, size=(NUM_CLASS, 3))
cls.end2end_model = End2EndModel( cls.end2end_model = End2EndModel(
Backend.ONNXRUNTIME, [''], Backend.ONNXRUNTIME, [''], device='cpu', deploy_cfg=deploy_cfg)
device='cpu',
class_names=class_names,
palette=palette,
deploy_cfg=deploy_cfg)
@classmethod @classmethod
def teardown_class(cls): def teardown_class(cls):
cls.wrapper.recover() cls.wrapper.recover()
@pytest.mark.parametrize( def test_forward(self):
'ori_shape', from mmseg.structures import SegDataSample
[[IMAGE_SIZE, IMAGE_SIZE, 3], [2 * IMAGE_SIZE, 2 * IMAGE_SIZE, 3]]) imgs = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)
def test_forward(self, ori_shape): data_samples = [generate_datasample(IMAGE_SIZE, IMAGE_SIZE)]
imgs = [torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)] results = self.end2end_model.forward(imgs, data_samples)
img_metas = [[{ assert len(results) == 1
'ori_shape': ori_shape, assert isinstance(results[0], SegDataSample)
'img_shape': [IMAGE_SIZE, IMAGE_SIZE, 3],
'scale_factor': [1., 1., 1., 1.],
}]]
results = self.end2end_model.forward(imgs, img_metas)
assert results is not None, 'failed to get output using '\
'End2EndModel'
def test_forward_test(self):
imgs = torch.rand(2, 3, IMAGE_SIZE, IMAGE_SIZE)
results = self.end2end_model.forward_test(imgs)
assert isinstance(results[0], np.ndarray)
def test_show_result(self):
input_img = np.zeros([IMAGE_SIZE, IMAGE_SIZE, 3])
img_path = NamedTemporaryFile(suffix='.jpg').name
result = [torch.rand(IMAGE_SIZE, IMAGE_SIZE)]
self.end2end_model.show_result(
input_img, result, '', show=False, out_file=img_path)
assert osp.exists(img_path), 'Fails to create drawn image.'
@pytest.mark.parametrize('from_file', [True, False])
@pytest.mark.parametrize('data_type', ['train', 'val', 'test'])
def test_get_classes_palette_from_config(from_file, data_type):
from mmseg.datasets import DATASETS
from mmdeploy.codebase.mmseg.deploy.segmentation_model import \
get_classes_palette_from_config
dataset_type = 'CityscapesDataset'
data_cfg = mmcv.Config({
'data': {
data_type:
dict(
type=dataset_type,
data_root='',
img_dir='',
ann_dir='',
pipeline=None)
}
})
if from_file:
config_path = NamedTemporaryFile(suffix='.py').name
with open(config_path, 'w') as file:
file.write(data_cfg.pretty_text)
data_cfg = config_path
classes, palette = get_classes_palette_from_config(data_cfg)
module = DATASETS.module_dict[dataset_type]
assert classes == module.CLASSES, \
f'fail to get CLASSES of dataset: {dataset_type}'
assert palette == module.PALETTE, \
f'fail to get PALETTE of dataset: {dataset_type}'
@backend_checker(Backend.ONNXRUNTIME) @backend_checker(Backend.ONNXRUNTIME)
def test_build_segmentation_model(): def test_build_segmentation_model():
model_cfg = mmcv.Config( model_cfg = mmengine.Config(
dict(data=dict(test={'type': 'CityscapesDataset'}))) dict(data=dict(test={'type': 'CityscapesDataset'})))
deploy_cfg = mmcv.Config( deploy_cfg = generate_mmseg_deploy_config()
dict(
backend_config=dict(type='onnxruntime'),
onnx_config=dict(output_names=['outputs']),
codebase_config=dict(type='mmseg')))
from mmdeploy.backend.onnxruntime import ORTWrapper from mmdeploy.backend.onnxruntime import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper}) ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

View File

@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmengine
import torch
from mmengine.structures import PixelData
from mmdeploy.apis import build_task_processor
from mmdeploy.utils import load_config
def generate_datasample(h, w):
from mmseg.structures import SegDataSample
metainfo = dict(img_shape=(h, w), ori_shape=(h, w), pad_shape=(h, w))
data_sample = SegDataSample()
data_sample.set_metainfo(metainfo)
seg_pred = torch.randint(0, 2, (1, h, w))
seg_gt = torch.randint(0, 2, (1, h, w))
data_sample.set_data(dict(pred_sem_seg=PixelData(**dict(data=seg_pred))))
data_sample.set_data(
dict(gt_sem_seg=PixelData(**dict(data=seg_gt, metainfo=metainfo))))
return data_sample
def generate_mmseg_deploy_config(backend='onnxruntime'):
deploy_cfg = mmengine.Config(
dict(
backend_config=dict(type=backend),
codebase_config=dict(type='mmseg', task='Segmentation'),
onnx_config=dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
input_shape=None,
input_names=['input'],
output_names=['output'])))
return deploy_cfg
def generate_mmseg_task_processor(model_cfg=None, deploy_cfg=None):
if model_cfg is None:
model_cfg = 'tests/test_codebase/test_mmseg/data/model.py'
if deploy_cfg is None:
deploy_cfg = generate_mmseg_deploy_config()
model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
return task_processor

View File

@ -447,10 +447,14 @@ def get_info_from_log_file(info_type: str, log_path: Path,
line_index = -1 line_index = -1
else: else:
line_index = -2 line_index = -2
if yaml_metric_key == 'mIoU':
if yaml_metric_key in ['accuracy_top-1', 'mIoU', 'Eval-PSNR']: metric_line = lines[-1]
info_value = metric_line.split('mIoU: ')[1].split(' ')[0]
info_value = float(info_value)
return info_value
elif yaml_metric_key in ['accuracy_top-1', 'Eval-PSNR']:
# info in last second line # info in last second line
# mmcls, mmseg, mmedit # mmcls, mmeg, mmedit
metric_line = lines[line_index - 1] metric_line = lines[line_index - 1]
elif yaml_metric_key == 'AP': elif yaml_metric_key == 'AP':
# info in last tenth line # info in last tenth line
@ -655,10 +659,6 @@ def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
f'--device {device_type} ' f'--device {device_type} '
codebase_name = get_codebase(str(deploy_cfg_path)).value codebase_name = get_codebase(str(deploy_cfg_path)).value
if codebase_name != 'mmedit' and codebase_name != 'mmdet':
# mmedit and mmdet dont --metric
cmd_str += f'--metrics {eval_name} '
logger.info(f'Process cmd = {cmd_str}') logger.info(f'Process cmd = {cmd_str}')
# Test backend # Test backend
shell_res = subprocess.run( shell_res = subprocess.run(
@ -937,6 +937,9 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
# Test the model # Test the model
if convert_result and test_type == 'precision': if convert_result and test_type == 'precision':
# Get evaluation metric from model config # Get evaluation metric from model config
if codebase_name == 'mmseg':
metrics_eval_list = model_cfg.val_evaluator.iou_metrics
else:
metrics_eval_list = model_cfg.test_evaluator.get('metric', []) metrics_eval_list = model_cfg.test_evaluator.get('metric', [])
if isinstance(metrics_eval_list, str): if isinstance(metrics_eval_list, str):
# some config is using str only # some config is using str only