Support mmseg:dev-1.x (#790)
* support pspnet + ort
* add rewriting for adapt_avg_pool
* test pspnet
* resize seg_pred to original image shape
* run with test.py
* keep as original
* fix ut of segmentation
* update var name
* fix export to torchscript
* sync with mmseg:test-1.x branch
* fix ut
* fix regression test for mmseg
* fix mmseg.ops
* update mmseg yml
* fix mmseg2.0 sdk
* fix adaptive pool
* update rewriting and tests
* fix sdk inputs
parent 0aad6359e2
commit 06028d6a21
@@ -4,5 +4,7 @@ codebase_config = dict(model_type='sdk')
 backend_config = dict(pipeline=[
     dict(type='LoadImageFromFile'),
-    dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
+    dict(type='LoadAnnotations'),
+    dict(
+        type='PackSegInputs', meta_keys=['img_path', 'ori_shape', 'img_shape'])
 ])
@@ -78,7 +78,7 @@ class PipelineCaller:
         call_id = self._call_id if call_id is None else call_id
         if call_id not in self._mp_dict:
             get_root_logger().error(
-                f'`{self._func_name}` with Call id: {call_id} failed. exit.')
+                f'`{self._func_name}` with Call id: {call_id} failed.')
             exit(1)
         ret = self._mp_dict[call_id]
         self._mp_dict.pop(call_id)
@@ -42,7 +42,10 @@ def torch2torchscript(img: Any,
     task_processor = build_task_processor(model_cfg, deploy_cfg, device)

+    torch_model = task_processor.build_pytorch_model(model_checkpoint)
-    _, model_inputs = task_processor.create_input(img, input_shape)
+    _, model_inputs = task_processor.create_input(
+        img,
+        input_shape,
+        data_preprocessor=getattr(torch_model, 'data_preprocessor', None))
     if not isinstance(model_inputs, torch.Tensor):
         model_inputs = model_inputs[0]

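The `getattr(..., 'data_preprocessor', None)` fallback above keeps models that lack a preprocessor attribute working. A minimal, self-contained sketch of the pattern (the conv layer is just a stand-in for a built PyTorch model):

import torch.nn as nn

# stand-in for the model returned by build_pytorch_model()
torch_model = nn.Conv2d(3, 8, 3)

# fall back to None when the model carries no data_preprocessor
data_preprocessor = getattr(torch_model, 'data_preprocessor', None)
assert data_preprocessor is None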
@@ -91,7 +91,9 @@ class BaseTask(metaclass=ABCMeta):
             nn.Module: An initialized torch model generated by other OpenMMLab
                 codebases.
         """
+        from mmengine.model import revert_sync_batchnorm
+        from mmengine.registry import MODELS

         model = deepcopy(self.model_cfg.model)
         preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
         model.setdefault('data_preprocessor', preprocess_cfg)
@@ -99,9 +101,10 @@ class BaseTask(metaclass=ABCMeta):
         if model_checkpoint is not None:
             from mmengine.runner.checkpoint import load_checkpoint
             load_checkpoint(model, model_checkpoint)

+        model = revert_sync_batchnorm(model)
         model = model.to(self.device)
         model.eval()

         return model

     def build_dataset(self,
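`revert_sync_batchnorm` swaps `SyncBatchNorm` layers for single-device equivalents so the model can run and be traced outside distributed training. A quick self-contained sketch:

import torch.nn as nn
from mmengine.model import revert_sync_batchnorm

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
model = revert_sync_batchnorm(model)
# the SyncBatchNorm layer is gone after conversion
assert not isinstance(model[1], nn.SyncBatchNorm)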
@@ -280,7 +283,10 @@ class BaseTask(metaclass=ABCMeta):
         visualizer = self.get_visualizer(window_name, save_dir)

         name = osp.splitext(save_name)[0]
-        image = mmcv.imread(image, channel_order='rgb')
+        if isinstance(image, str):
+            image = mmcv.imread(image, channel_order='rgb')
+        assert isinstance(image, np.ndarray)

         visualizer.add_datasample(
             name,
             image,
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .mmsegmentation import MMSegmentation
 from .segmentation import Segmentation

-__all__ = ['MMSegmentation', 'Segmentation']
+__all__ = ['Segmentation']
@@ -1,148 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Union
-
-import mmengine
-import torch
-from mmcv.utils import Registry
-from torch.utils.data import DataLoader, Dataset
-
-from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
-from mmdeploy.utils import Codebase, get_task_type
-
-
-def __build_mmseg_task(model_cfg: mmengine.Config, deploy_cfg: mmengine.Config,
-                       device: str, registry: Registry) -> BaseTask:
-    task = get_task_type(deploy_cfg)
-    return registry.module_dict[task.value](model_cfg, deploy_cfg, device)
-
-
-MMSEG_TASK = Registry('mmseg_tasks', build_func=__build_mmseg_task)
-
-
-@CODEBASE.register_module(Codebase.MMSEG.value)
-class MMSegmentation(MMCodebase):
-    """mmsegmentation codebase class."""
-
-    task_registry = MMSEG_TASK
-
-    def __init__(self):
-        super(MMSegmentation, self).__init__()
-
-    @staticmethod
-    def build_task_processor(model_cfg: mmengine.Config,
-                             deploy_cfg: mmengine.Config, device: str):
-        """The interface to build the task processors of mmseg.
-
-        Args:
-            model_cfg (str | mmengine.Config): Model config file.
-            deploy_cfg (str | mmengine.Config): Deployment config file.
-            device (str): A string specifying device type.
-
-        Returns:
-            BaseTask: A task processor.
-        """
-        return MMSEG_TASK.build(model_cfg, deploy_cfg, device)
-
-    @staticmethod
-    def build_dataset(dataset_cfg: Union[str, mmengine.Config],
-                      dataset_type: str = 'val',
-                      **kwargs) -> Dataset:
-        """Build dataset for segmentation.
-
-        Args:
-            dataset_cfg (str | mmengine.Config): The input dataset config.
-            dataset_type (str): A string represents dataset type, e.g.:
-                'train', 'test', 'val'. Defaults to 'val'.
-
-        Returns:
-            Dataset: A PyTorch dataset.
-        """
-        from mmseg.datasets import build_dataset as build_dataset_mmseg
-
-        assert dataset_type in dataset_cfg.data
-        data_cfg = dataset_cfg.data[dataset_type]
-        dataset = build_dataset_mmseg(data_cfg)
-        return dataset
-
-    @staticmethod
-    def build_dataloader(dataset: Dataset,
-                         samples_per_gpu: int,
-                         workers_per_gpu: int,
-                         num_gpus: int = 1,
-                         dist: bool = False,
-                         shuffle: bool = False,
-                         seed: Optional[int] = None,
-                         drop_last: bool = False,
-                         pin_memory: bool = True,
-                         persistent_workers: bool = True,
-                         **kwargs) -> DataLoader:
-        """Build dataloader for segmentation.
-
-        Args:
-            dataset (Dataset): Input dataset.
-            samples_per_gpu (int): Number of training samples on each GPU,
-                i.e., batch size of each GPU.
-            workers_per_gpu (int): How many subprocesses to use for data
-                loading for each GPU.
-            num_gpus (int): Number of GPUs. Only used in non-distributed
-                training.
-            dist (bool): Distributed training/test or not.
-                Defaults to `False`.
-            shuffle (bool): Whether to shuffle the data at every epoch.
-                Defaults to `False`.
-            seed (int): An integer set to be seed. Default is `None`.
-            drop_last (bool): Whether to drop the last incomplete batch in
-                epoch. Default to `False`.
-            pin_memory (bool): Whether to use pin_memory in DataLoader.
-                Default is `True`.
-            persistent_workers (bool): If `True`, the data loader will not
-                shutdown the worker processes after a dataset has been
-                consumed once. This allows to maintain the workers Dataset
-                instances alive. The argument also has effect in
-                PyTorch>=1.7.0. Default is `True`.
-            kwargs: Any other keyword argument to be used to initialize
-                DataLoader.
-
-        Returns:
-            DataLoader: A PyTorch dataloader.
-        """
-        from mmseg.datasets import build_dataloader as build_dataloader_mmseg
-        return build_dataloader_mmseg(
-            dataset,
-            samples_per_gpu,
-            workers_per_gpu,
-            num_gpus=num_gpus,
-            dist=dist,
-            shuffle=shuffle,
-            seed=seed,
-            drop_last=drop_last,
-            pin_memory=pin_memory,
-            persistent_workers=persistent_workers,
-            **kwargs)
-
-    @staticmethod
-    def single_gpu_test(model: torch.nn.Module,
-                        data_loader: DataLoader,
-                        show: bool = False,
-                        out_dir: Optional[str] = None,
-                        pre_eval: bool = True,
-                        **kwargs):
-        """Run test with single gpu.
-
-        Args:
-            model (torch.nn.Module): Input model from nn.Module.
-            data_loader (DataLoader): PyTorch data loader.
-            show (bool): Specifying whether to show plotted results. Defaults
-                to `False`.
-            out_dir (str): A directory to save results, defaults to `None`.
-            pre_eval (bool): Use dataset.pre_eval() function to generate
-                pre_results for metric evaluation. Mutually exclusive with
-                efficient_test and format_results. Default: False.
-
-        Returns:
-            list: The prediction results.
-        """
-        from mmseg.apis import single_gpu_test
-        outputs = single_gpu_test(
-            model, data_loader, show, out_dir, pre_eval=pre_eval, **kwargs)
-        return outputs
@@ -1,15 +1,19 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+import os.path as osp
+from collections import defaultdict
+from copy import deepcopy
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union

 import mmcv
 import mmengine
 import numpy as np
 import torch
-from torch.utils.data import Dataset
+from mmengine import Config
+from mmengine.model import BaseDataPreprocessor
+from mmengine.registry import Registry

-from mmdeploy.codebase.base import BaseTask
-from mmdeploy.utils import Task, get_input_shape
-from .mmsegmentation import MMSEG_TASK
+from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
+from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger


 def process_model_config(model_cfg: mmengine.Config,
@@ -27,22 +31,81 @@ def process_model_config(model_cfg: mmengine.Config,
     Returns:
         mmengine.Config: the model config after processing.
     """
-    from mmseg.apis.inference import LoadImage
-    cfg = model_cfg.copy()
+    cfg = deepcopy(model_cfg)

     if isinstance(imgs[0], np.ndarray):
-        cfg = cfg.copy()
         # set loading pipeline type
-        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+        cfg.test_pipeline[0].type = 'LoadImageFromNDArray'

+    # remove some training related pipeline steps
+    removed_indices = []
+    for i in range(len(cfg.test_pipeline)):
+        if cfg.test_pipeline[i]['type'] in ['LoadAnnotations']:
+            removed_indices.append(i)
+    for i in reversed(removed_indices):
+        cfg.test_pipeline.pop(i)

     # for static exporting
     if input_shape is not None:
-        cfg.data.test.pipeline[1]['img_scale'] = tuple(input_shape)
-        cfg.data.test.pipeline[1]['transforms'][0]['keep_ratio'] = False
-    cfg.data.test.pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+        found_resize = False
+        for i in range(len(cfg.test_pipeline)):
+            if 'Resize' == cfg.test_pipeline[i]['type']:
+                cfg.test_pipeline[i]['scale'] = tuple(input_shape)
+                cfg.test_pipeline[i]['keep_ratio'] = False
+                found_resize = True
+        if not found_resize:
+            logger = get_root_logger()
+            logger.warning(
+                f'Not found Resize in test_pipeline: {cfg.test_pipeline}')

     return cfg


+def _get_dataset_metainfo(model_cfg: Config):
+    """Get metainfo of dataset.
+
+    Args:
+        model_cfg (Config): Input model Config object.
+
+    Returns:
+        (list[str], list[np.ndarray]): Class names and palette.
+    """
+    from mmseg import datasets  # noqa
+    from mmseg.registry import DATASETS
+
+    module_dict = DATASETS.module_dict
+
+    for dataloader_name in [
+            'test_dataloader', 'val_dataloader', 'train_dataloader'
+    ]:
+        if dataloader_name not in model_cfg:
+            continue
+        dataloader_cfg = model_cfg[dataloader_name]
+        dataset_cfg = dataloader_cfg.dataset
+        dataset_mmseg = module_dict.get(dataset_cfg.type, None)
+        if dataset_mmseg is None:
+            continue
+        if hasattr(dataset_mmseg, '_load_metainfo') and isinstance(
+                dataset_mmseg._load_metainfo, Callable):
+            meta = dataset_mmseg._load_metainfo(
+                dataset_cfg.get('metainfo', None))
+            if meta is not None:
+                return meta
+        if hasattr(dataset_mmseg, 'METAINFO'):
+            return dataset_mmseg.METAINFO
+
+    return None
+
+
+MMSEG_TASK = Registry('mmseg_tasks')
+
+
+@CODEBASE.register_module(Codebase.MMSEG.value)
+class MMSegmentation(MMCodebase):
+    """mmsegmentation codebase class."""
+    task_registry = MMSEG_TASK


 @MMSEG_TASK.register_module(Task.SEGMENTATION.value)
 class Segmentation(BaseTask):
     """Segmentation task class.
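The rewrite above pins every `Resize` step of `test_pipeline` to the deploy shape for static export. A standalone sketch of the same idea, with a made-up mmseg-style pipeline:

# hypothetical test pipeline; the entries are illustrative
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='PackSegInputs'),
]

input_shape = (1024, 512)  # deploy-time static shape
for step in test_pipeline:
    if step['type'] == 'Resize':
        step['scale'] = tuple(input_shape)  # fixed scale for static export
        step['keep_ratio'] = False          # exact shape, no aspect keeping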
@@ -70,43 +133,23 @@ class Segmentation(BaseTask):
             nn.Module: An initialized backend model.
         """
         from .segmentation_model import build_segmentation_model

+        data_preprocessor = self.model_cfg.model.data_preprocessor
         model = build_segmentation_model(
             model_files,
             self.model_cfg,
             self.deploy_cfg,
             device=self.device,
-            **kwargs)
-        return model.eval()
+            data_preprocessor=data_preprocessor)
+        model = model.to(self.device).eval()
+        return model

-    def build_pytorch_model(self,
-                            model_checkpoint: Optional[str] = None,
-                            cfg_options: Optional[Dict] = None,
-                            **kwargs) -> torch.nn.Module:
-        """Initialize torch model.
-
-        Args:
-            model_checkpoint (str): The checkpoint file of torch model,
-                defaults to `None`.
-            cfg_options (dict): Optional config key-pair parameters.
-
-        Returns:
-            nn.Module: An initialized torch model generated by OpenMMLab
-                codebases.
-        """
-        from mmcv.cnn.utils import revert_sync_batchnorm
-        if self.from_mmrazor:
-            from mmrazor.apis import init_mmseg_model as init_segmentor
-        else:
-            from mmseg.apis import init_segmentor
-
-        model = init_segmentor(self.model_cfg, model_checkpoint, self.device)
-        model = revert_sync_batchnorm(model)
-        return model.eval()
-
-    def create_input(self,
-                     imgs: Union[str, np.ndarray],
-                     input_shape: Sequence[int] = None) \
-            -> Tuple[Dict, torch.Tensor]:
+    def create_input(
+            self,
+            imgs: Union[str, np.ndarray, Sequence],
+            input_shape: Sequence[int] = None,
+            data_preprocessor: Optional[BaseDataPreprocessor] = None
+    ) -> Tuple[Dict, torch.Tensor]:
         """Create input for segmentor.

         Args:
@@ -118,43 +161,64 @@ class Segmentation(BaseTask):
         Returns:
             tuple: (data, img), meta information for the input image and input.
         """
-        from mmcv.parallel import collate, scatter
-        from mmseg.datasets.pipelines import Compose
-        if not isinstance(imgs, (list, tuple)):
+        from mmengine.dataset import Compose
+
+        if not isinstance(imgs, (tuple, list)):
             imgs = [imgs]
         cfg = process_model_config(self.model_cfg, imgs, input_shape)
-        test_pipeline = Compose(cfg.data.test.pipeline)
-        data_list = []
+        test_pipeline = Compose(cfg.test_pipeline)
+        batch_data = defaultdict(list)
         for img in imgs:
-            # prepare data
-            data = dict(img=img)
-            # build the data pipeline
+            if isinstance(img, str):
+                data = dict(img_path=img)
+            else:
+                data = dict(img=img)
             data = test_pipeline(data)
-            data_list.append(data)
+            batch_data['inputs'].append(data['inputs'])
+            batch_data['data_samples'].append(data['data_samples'])

-        data = collate(data_list, samples_per_gpu=len(imgs))
-
-        data['img_metas'] = [
-            img_metas.data[0] for img_metas in data['img_metas']
-        ]
-        data['img'] = [img.data[0][None, :] for img in data['img']]
-        if self.device != 'cpu':
-            data = scatter(data, [self.device])[0]
-
-        return data, data['img']
+        # batch_data = pseudo_collate([batch_data])
+        if data_preprocessor is not None:
+            batch_data = data_preprocessor(batch_data, False)
+            input_tensor = batch_data['inputs']
+        else:
+            input_tensor = BaseTask.get_tensor_from_input(batch_data)
+        return batch_data, input_tensor
+
+    def get_visualizer(self, name: str, save_dir: str):
+        """Get the visualizer instance.
+
+        Args:
+            name (str): Name of the visualizer instance.
+            save_dir (str): Directory to save visualized results.
+
+        Returns:
+            Visualizer: A visualizer instance.
+        """
+        # import so that SegLocalVisualizer can be built from the registry
+        from mmseg.visualization import SegLocalVisualizer  # noqa: F401,F403
+
+        visualizer = super().get_visualizer(name, save_dir)
+        # force to change save_dir instead of
+        # save_dir/vis_data/vis_image/xx.jpg
+        visualizer._vis_backends['LocalVisBackend']._save_dir = save_dir
+        visualizer._vis_backends['LocalVisBackend']._img_save_dir = '.'
+        metainfo = _get_dataset_metainfo(self.model_cfg)
+        if metainfo is not None:
+            visualizer.dataset_meta = metainfo
+        return visualizer

     def visualize(self,
                   model,
                   image: Union[str, np.ndarray],
                   result: list,
                   output_file: str,
                   window_name: str = '',
                   show_result: bool = False,
-                  opacity: float = 0.5):
-        """Visualize predictions of a model.
+                  opacity: float = 0.5,
+                  **kwargs):
+        """Visualize segmentation predictions.

         Args:
             model (nn.Module): Input model.
             image (str | np.ndarray): Input image to draw predictions on.
             result (list): A list of predictions.
             output_file (str): Output file to save drawn image.
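The batching pattern above collects per-image pipeline outputs into one dict of lists before the data preprocessor runs. A self-contained sketch (the sample dicts are stand-ins for real pipeline outputs):

from collections import defaultdict

import torch

# stand-ins for test_pipeline outputs: one dict per image
pipeline_outputs = [
    dict(inputs=torch.rand(3, 512, 1024), data_samples=object()),
    dict(inputs=torch.rand(3, 512, 1024), data_samples=object()),
]

batch_data = defaultdict(list)
for data in pipeline_outputs:
    batch_data['inputs'].append(data['inputs'])
    batch_data['data_samples'].append(data['data_samples'])

assert len(batch_data['inputs']) == 2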
@@ -165,88 +229,18 @@ class Segmentation(BaseTask):
             opacity: (float): Opacity of painted segmentation map.
                 Defaults to `0.5`.
         """
-        show_img = mmcv.imread(image) if isinstance(image, str) else image
-        output_file = None if show_result else output_file
-        # Need to wrap the result with a list for mmseg
-        result = [result]
-        model.show_result(
-            show_img,
-            result,
-            out_file=output_file,
-            win_name=window_name,
-            show=show_result,
-            opacity=opacity)
-
-    @staticmethod
-    def run_inference(model, model_inputs: Dict[str, torch.Tensor]):
-        """Run inference once for a segmentation model of mmseg.
-
-        Args:
-            model (nn.Module): Input model.
-            model_inputs (dict): A dict containing model inputs tensor and
-                meta info.
-
-        Returns:
-            list: The predictions of model inference.
-        """
-        return model(**model_inputs, return_loss=False, rescale=True)
+        save_dir, filename = osp.split(output_file)
+        visualizer = self.get_visualizer(window_name, save_dir)
+        name = osp.splitext(filename)[0]
+        if isinstance(image, str):
+            image = mmcv.imread(image, channel_order='rgb')
+        visualizer.add_datasample(
+            name, image, data_sample=result.cpu(), show=show_result)

     @staticmethod
     def get_partition_cfg(partition_type: str) -> Dict:
         raise NotImplementedError('Not supported yet.')

-    @staticmethod
-    def get_tensor_from_input(input_data: Dict[str, Any]) -> torch.Tensor:
-        """Get input tensor from input data.
-
-        Args:
-            input_data (dict): Input data containing meta info and image
-                tensor.
-
-        Returns:
-            torch.Tensor: An image in `Tensor`.
-        """
-        return input_data['img'][0]
-
-    @staticmethod
-    def evaluate_outputs(model_cfg,
-                         outputs: Sequence,
-                         dataset: Dataset,
-                         metrics: Optional[str] = None,
-                         out: Optional[str] = None,
-                         metric_options: Optional[dict] = None,
-                         format_only: bool = False,
-                         log_file: Optional[str] = None):
-        """Perform post-processing on the predictions of the model.
-
-        Args:
-            outputs (list): A list of predictions of model inference.
-            dataset (Dataset): Input dataset to run test.
-            model_cfg (mmengine.Config): The model config.
-            metrics (str): Evaluation metrics, which depend on
-                the codebase and the dataset, e.g., "mIoU" for generic
-                datasets, and "cityscapes" for Cityscapes in mmseg.
-            out (str): Output result file in pickle format, defaults to `None`.
-            metric_options (dict): Custom options for evaluation, will be
-                kwargs for dataset.evaluate() function. Defaults to `None`.
-            format_only (bool): Format the output results without performing
-                evaluation. It is useful when you want to format the result
-                to a specific format and submit it to the test server. Defaults
-                to `False`.
-            log_file (str | None): The file to write the evaluation results.
-                Defaults to `None` and the results will only print on stdout.
-        """
-        from mmcv.utils import get_logger
-        logger = get_logger('test', log_file=log_file)
-
-        if out:
-            logger.debug(f'writing results to {out}')
-            mmcv.dump(outputs, out)
-        kwargs = {} if metric_options is None else metric_options
-        if format_only:
-            dataset.format_results(outputs, **kwargs)
-        if metrics:
-            dataset.evaluate(outputs, metrics, logger=logger, **kwargs)

     def get_preprocess(self) -> Dict:
         """Get the preprocess information for SDK.

@@ -254,10 +248,28 @@ class Segmentation(BaseTask):
             dict: Composed of the preprocess information.
         """
         input_shape = get_input_shape(self.deploy_cfg)
-        load_from_file = self.model_cfg.data.test.pipeline[0]
+        load_from_file = self.model_cfg.test_pipeline[0]
         model_cfg = process_model_config(self.model_cfg, [''], input_shape)
-        preprocess = model_cfg.data.test.pipeline
+        preprocess = deepcopy(model_cfg.test_pipeline)
         preprocess[0] = load_from_file
+        dp = self.model_cfg.data_preprocessor
+        assert preprocess[1].type == 'Resize'
+        preprocess[1]['size'] = list(reversed(preprocess[1].pop('scale')))
+        if preprocess[-1].type == 'PackSegInputs':
+            preprocess[-1] = dict(
+                type='Normalize',
+                mean=dp.mean,
+                std=dp.std,
+                to_rgb=dp.bgr_to_rgb)
+            preprocess.append(dict(type='ImageToTensor', keys=['img']))
+            preprocess.append(
+                dict(
+                    type='Collect',
+                    keys=['img'],
+                    meta_keys=[
+                        'img_shape', 'pad_shape', 'ori_shape', 'img_norm_cfg',
+                        'scale_factor'
+                    ]))
         return preprocess

     def get_postprocess(self) -> Dict:
@@ -1,26 +1,23 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List, Optional, Sequence, Union

-import mmcv
 import mmengine
-import numpy as np
 import torch
-from mmcv.utils import Registry
-from mmseg.datasets import DATASETS
-from mmseg.models.segmentors.base import BaseSegmentor
-from mmseg.ops import resize
+from mmengine import Config
+from mmengine.model import BaseDataPreprocessor
+from mmengine.registry import Registry
+from mmengine.structures import BaseDataElement, PixelData
+from torch import nn

 from mmdeploy.codebase.base import BaseBackendModel
 from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
-                            load_config)
+                            get_root_logger, load_config)


-def __build_backend_model(cls_name: str, registry: Registry, *args, **kwargs):
-    return registry.module_dict[cls_name](*args, **kwargs)
-
-
-__BACKEND_MODEL = mmcv.utils.Registry(
-    'backend_segmentors', build_func=__build_backend_model)
+__BACKEND_MODEL = Registry('backend_segmentors')


 @__BACKEND_MODEL.register_module('end2end')
@@ -42,14 +39,13 @@ class End2EndModel(BaseBackendModel):
                  backend: Backend,
                  backend_files: Sequence[str],
                  device: str,
-                 class_names: Sequence[str],
-                 palette: np.ndarray,
-                 deploy_cfg: Union[str, mmcv.Config] = None,
+                 deploy_cfg: Union[str, Config] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                  **kwargs):
-        super(End2EndModel, self).__init__(deploy_cfg=deploy_cfg)
-        self.CLASSES = class_names
-        self.PALETTE = palette
+        super(End2EndModel, self).__init__(
+            deploy_cfg=deploy_cfg, data_preprocessor=data_preprocessor)
         self.deploy_cfg = deploy_cfg
         self.device = device
         self._init_wrapper(
             backend=backend,
             backend_files=backend_files,
@@ -67,8 +63,10 @@ class End2EndModel(BaseBackendModel):
             deploy_cfg=self.deploy_cfg,
             **kwargs)

-    def forward(self, img: Sequence[torch.Tensor],
-                img_metas: Sequence[Sequence[dict]], *args, **kwargs):
+    def forward(self,
+                inputs: torch.Tensor,
+                data_samples: Optional[List[BaseDataElement]] = None,
+                mode: str = 'predict'):
         """Run forward inference.

         Args:
@@ -82,78 +80,47 @@ class End2EndModel(BaseBackendModel):
         Returns:
             list: A list of predictions.
         """
-        input_img = img[0].contiguous()
-        outputs = self.forward_test(input_img, img_metas, *args, **kwargs)
-        seg_pred = outputs[0]
-        # whole mode supports dynamic shape
-        ori_shape = img_metas[0][0]['ori_shape']
-        if not (ori_shape[0] == seg_pred.shape[-2]
-                and ori_shape[1] == seg_pred.shape[-1]):
-            seg_pred = torch.from_numpy(seg_pred).float()
-            seg_pred = resize(
-                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
-            seg_pred = seg_pred.long().detach().cpu().numpy()
-        # remove unnecessary dim
-        seg_pred = seg_pred.squeeze(1)
-        seg_pred = list(seg_pred)
-        return seg_pred
+        assert mode == 'predict', \
+            'Backend model only supports mode==predict,' f' but got {mode}'
+        if inputs.device != torch.device(self.device):
+            get_root_logger().warning(f'expect input device {self.device}'
+                                      f' but got {inputs.device}.')
+        inputs = inputs.to(self.device)
+        batch_outputs = self.wrapper({self.input_name:
+                                      inputs})[self.output_names[0]]
+        return self.pack_result(batch_outputs, data_samples)

-    def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \
-            List[np.ndarray]:
-        """The interface for forward test.
-
-        Args:
-            imgs (torch.Tensor): Input image(s) in [N x C x H x W] format.
-
-        Returns:
-            List[np.ndarray]: A list of segmentation maps.
-        """
-        outputs = self.wrapper({self.input_name: imgs})
-        outputs = self.wrapper.output_to_list(outputs)
-        outputs = [out.detach().cpu().numpy() for out in outputs]
-        return outputs
-
-    def show_result(self,
-                    img: np.ndarray,
-                    result: list,
-                    win_name: str = '',
-                    palette: Optional[np.ndarray] = None,
-                    show: bool = True,
-                    opacity: float = 0.5,
-                    out_file: str = None):
-        """Show predictions of segmentation.
-
-        Args:
-            img: (np.ndarray): Input image to draw predictions.
-            result (list): A list of predictions.
-            win_name (str): The name of visualization window. Default is ''.
-            palette (np.ndarray): The palette of segmentation map.
-            show (bool): Whether to show plotted image in windows. Defaults to
-                `True`.
-            opacity: (float): Opacity of painted segmentation map.
-                Defaults to `0.5`.
-            out_file (str): Output image file to save drawn predictions.
-
-        Returns:
-            np.ndarray: Drawn image, only if not `show` or `out_file`.
-        """
-        palette = self.PALETTE if palette is None else palette
-        return BaseSegmentor.show_result(
-            self,
-            img,
-            result,
-            palette=palette,
-            opacity=opacity,
-            show=show,
-            win_name=win_name,
-            out_file=out_file)
+    def pack_result(self, batch_outputs, data_samples):
+        predictions = []
+        for seg_pred, data_sample in zip(batch_outputs, data_samples):
+            # resize seg_pred to original image shape
+            metainfo = data_sample.metainfo
+            if metainfo['ori_shape'] != metainfo['img_shape']:
+                from mmseg.models.utils import resize
+                ori_type = seg_pred.dtype
+                seg_pred = resize(
+                    seg_pred.unsqueeze(0).to(torch.float32),
+                    size=metainfo['ori_shape'],
+                    mode='nearest').squeeze(0).to(ori_type)
+            data_sample.set_data(
+                dict(pred_sem_seg=PixelData(**dict(data=seg_pred))))
+            predictions.append(data_sample)
+        return predictions


 @__BACKEND_MODEL.register_module('sdk')
 class SDKEnd2EndModel(End2EndModel):
     """SDK inference class, converts SDK output to mmseg format."""

-    def forward(self, img: Sequence[torch.Tensor],
-                img_metas: Sequence[Sequence[dict]], *args, **kwargs):
+    def __init__(self, *args, **kwargs):
+        kwargs['data_preprocessor'] = None
+        super(SDKEnd2EndModel, self).__init__(*args, **kwargs)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                data_samples: Optional[List[BaseDataElement]] = None,
+                mode: str = 'predict'):
         """Run forward inference.

         Args:
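The `pack_result` resize can be reproduced with plain `torch.nn.functional.interpolate`, which mmseg's `resize` wraps. A minimal sketch of keeping integer class ids through a nearest-neighbour resize (shapes are made up):

import torch
import torch.nn.functional as F

seg_pred = torch.randint(0, 19, (1, 512, 1024))  # per-pixel class ids
ori_type = seg_pred.dtype
seg_pred = F.interpolate(
    seg_pred.unsqueeze(0).float(),               # N, C, H, W float input
    size=(1024, 2048),                           # original image shape
    mode='nearest').squeeze(0).to(ori_type)      # back to integer class ids
assert seg_pred.shape == (1, 1024, 2048)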
@@ -167,42 +134,26 @@ class SDKEnd2EndModel(End2EndModel):
         Returns:
             list: A list of predictions.
         """
-        masks = self.wrapper.invoke(img[0].contiguous().detach().cpu().numpy())
-        return masks
+        if isinstance(inputs, list):
+            inputs = inputs[0]
+        # inputs are c,h,w; the sdk requires h,w,c
+        inputs = inputs.permute(1, 2, 0)
+        outputs = self.wrapper.invoke(
+            inputs.contiguous().detach().cpu().numpy())
+        batch_outputs = torch.from_numpy(outputs).to(torch.int64).to(
+            self.device)
+        batch_outputs = batch_outputs.unsqueeze(0).unsqueeze(0)
+        return self.pack_result(batch_outputs, data_samples)


-def get_classes_palette_from_config(model_cfg: Union[str, mmengine.Config]):
-    """Get class names and palette from config.
-
-    Args:
-        model_cfg (str | mmengine.Config): Input model config file or
-            Config object.
-
-    Returns:
-        tuple(Sequence[str], np.ndarray): A list of strings specifying names of
-            different classes and the palette of the segmentation map.
-    """
-    # load cfg if necessary
-    model_cfg = load_config(model_cfg)[0]
-
-    module_dict = DATASETS.module_dict
-    data_cfg = model_cfg.data
-
-    if 'val' in data_cfg:
-        module = module_dict[data_cfg.val.type]
-    elif 'test' in data_cfg:
-        module = module_dict[data_cfg.test.type]
-    elif 'train' in data_cfg:
-        module = module_dict[data_cfg.train.type]
-    else:
-        raise RuntimeError(f'No dataset config found in: {model_cfg}')
-
-    return module.CLASSES, module.PALETTE
-
-
-def build_segmentation_model(model_files: Sequence[str],
-                             model_cfg: Union[str, mmengine.Config],
-                             deploy_cfg: Union[str, mmengine.Config],
-                             device: str, **kwargs):
+def build_segmentation_model(
+        model_files: Sequence[str],
+        model_cfg: Union[str, Config],
+        deploy_cfg: Union[str, Config],
+        device: str,
+        data_preprocessor: Optional[Union[Config,
+                                          BaseDataPreprocessor]] = None,
+        **kwargs):
     """Build object segmentation model for different backends.

     Args:
@@ -212,25 +163,25 @@ def build_segmentation_model(model_files: Sequence[str],
         deploy_cfg (str | mmengine.Config): Input deployment config file or
             Config object.
         device (str): Device to input model.
+        data_preprocessor (BaseDataPreprocessor | Config): The data
+            preprocessor of the model.

     Returns:
         BaseBackendModel: Segmentor for a configured backend.
     """
     # load cfg if necessary
     deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)

     backend = get_backend(deploy_cfg)
     model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')
-    class_names, palette = get_classes_palette_from_config(model_cfg)

     backend_segmentor = __BACKEND_MODEL.build(
-        model_type,
-        backend=backend,
-        backend_files=model_files,
-        device=device,
-        class_names=class_names,
-        palette=palette,
-        deploy_cfg=deploy_cfg,
-        **kwargs)
+        dict(
+            type=model_type,
+            backend=backend,
+            backend_files=model_files,
+            device=device,
+            deploy_cfg=deploy_cfg,
+            data_preprocessor=data_preprocessor,
+            **kwargs))

     return backend_segmentor
@@ -1,6 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .aspp_head import aspp_head__forward
 from .ema_head import ema_module__forward
-from .psp_head import ppm__forward

-__all__ = ['aspp_head__forward', 'ppm__forward', 'ema_module__forward']
+__all__ = ['ema_module__forward']
@@ -1,43 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-from mmseg.ops import resize
-
-from mmdeploy.core import FUNCTION_REWRITER
-from mmdeploy.utils import is_dynamic_shape
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmseg.models.decode_heads.ASPPHead.forward')
-def aspp_head__forward(ctx, self, inputs):
-    """Rewrite `forward` for default backend.
-
-    Support configured dynamic/static shape in resize op.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        self: The instance of the original class.
-        inputs (list[Tensor]): List of multi-level img features.
-
-    Returns:
-        torch.Tensor: Output segmentation map.
-    """
-    x = self._transform_inputs(inputs)
-    deploy_cfg = ctx.cfg
-    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    # get origin input shape as tensor to support onnx dynamic shape
-    size = x.shape[2:]
-    if not is_dynamic_flag:
-        size = [int(val) for val in size]
-
-    aspp_outs = [
-        resize(
-            self.image_pool(x),
-            size=size,
-            mode='bilinear',
-            align_corners=self.align_corners)
-    ]
-    aspp_outs.extend(self.aspp_modules(x))
-    aspp_outs = torch.cat(aspp_outs, dim=1)
-    output = self.bottleneck(aspp_outs)
-    output = self.cls_seg(output)
-    return output
@@ -1,52 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-import torch.nn as nn
-from mmseg.ops import resize
-
-from mmdeploy.core import FUNCTION_REWRITER
-from mmdeploy.utils import IR, get_root_logger, is_dynamic_shape
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmseg.models.decode_heads.psp_head.PPM.forward', ir=IR.ONNX)
-def ppm__forward(ctx, self, x):
-    """Rewrite `forward` for default backend.
-
-    Support configured dynamic/static shape in resize op.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        self: The instance of the original class.
-        x (Tensor): The transformed input feature.
-
-    Returns:
-        List[torch.Tensor]: Up-sampled segmentation maps of different
-            scales.
-    """
-    deploy_cfg = ctx.cfg
-    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    # get origin input shape as tensor to support onnx dynamic shape
-    size = x.shape[2:]
-    if not is_dynamic_flag:
-        size = [int(val) for val in size]
-
-    ppm_outs = []
-    for ppm in self:
-        if isinstance(ppm[0], nn.AdaptiveAvgPool2d) and \
-                ppm[0].output_size != 1:
-            if is_dynamic_flag:
-                logger = get_root_logger()
-                logger.warning('`AdaptiveAvgPool2d` would be '
-                               'replaced to `AvgPool2d` explicitly')
-            # replace AdaptiveAvgPool2d with AvgPool2d explicitly
-            output_size = 2 * [ppm[0].output_size]
-            k = [int(size[i] / output_size[i]) for i in range(0, len(size))]
-            ppm[0] = nn.AvgPool2d(k, stride=k, padding=0, ceil_mode=False)
-        ppm_out = ppm(x)
-        upsampled_ppm_out = resize(
-            ppm_out,
-            size=size,
-            mode='bilinear',
-            align_corners=self.align_corners)
-        ppm_outs.append(upsampled_ppm_out)
-    return ppm_outs
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import base_segmentor__forward
-from .encoder_decoder import encoder_decoder__simple_test
+from .cascade_encoder_decoder import cascade_encoder_decoder__predict
+from .encoder_decoder import encoder_decoder__predict

-__all__ = ['base_segmentor__forward', 'encoder_decoder__simple_test']
+__all__ = [
+    'base_segmentor__forward', 'encoder_decoder__predict',
+    'cascade_encoder_decoder__predict'
+]
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from mmseg.structures import SegDataSample

 from mmdeploy.core import FUNCTION_REWRITER
 from mmdeploy.utils import is_dynamic_shape
@@ -7,7 +7,12 @@ from mmdeploy.utils import is_dynamic_shape

 @FUNCTION_REWRITER.register_rewriter(
     func_name='mmseg.models.segmentors.BaseSegmentor.forward')
-def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
+def base_segmentor__forward(ctx,
+                            self,
+                            inputs,
+                            data_samples=None,
+                            mode='predict',
+                            **kwargs):
     """Rewrite `forward` for default backend.

     Support configured dynamic/static shape for model input.
@@ -15,27 +20,23 @@ def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
     Args:
         ctx (ContextCaller): The context with additional information.
         self: The instance of the original class.
-        img (Tensor | List[Tensor]): Input image tensor(s).
-        img_metas (List[dict]): List of dicts containing image's meta
+        inputs (Tensor | List[Tensor]): Input image tensor(s).
+        data_samples (List[dict]): List of dicts containing image's meta
             information such as `img_shape`.

     Returns:
         torch.Tensor: Output segmentation map of shape [N, 1, H, W].
     """
-    if img_metas is None:
-        img_metas = [{}]
-    else:
-        assert len(img_metas) == 1, 'do not support aug_test'
-        img_metas = img_metas[0]
-    if isinstance(img, list):
-        img = img[0]
-    assert isinstance(img, torch.Tensor)
+    if data_samples is None:
+        data_samples = [SegDataSample()]

     deploy_cfg = ctx.cfg
     is_dynamic_flag = is_dynamic_shape(deploy_cfg)
     # get origin input shape as tensor to support onnx dynamic shape
-    img_shape = img.shape[2:]
+    img_shape = inputs.shape[2:]
     if not is_dynamic_flag:
         img_shape = [int(val) for val in img_shape]
-    img_metas[0]['img_shape'] = img_shape
-    return self.simple_test(img, img_metas, **kwargs)
+    for data_sample in data_samples:
+        data_sample.set_field(
+            name='img_shape', value=img_shape, field_type='metainfo')
+    return self.predict(inputs, data_samples)
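`set_field` with `field_type='metainfo'` is the mmengine way to stamp metadata, rather than data, onto a sample; a minimal sketch using the base class:

from mmengine.structures import BaseDataElement

sample = BaseDataElement()
# record the exported input shape as metainfo
sample.set_field(name='img_shape', value=(512, 1024), field_type='metainfo')
assert sample.metainfo['img_shape'] == (512, 1024)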
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmseg.models.segmentors.CascadeEncoderDecoder.predict')
+def cascade_encoder_decoder__predict(ctx, self, inputs, data_samples,
+                                     **kwargs):
+    """Rewrite `predict` for default backend.
+
+    1. only support mode=`whole` inference
+    2. skip calling self.postprocess_result
+
+    Args:
+        ctx (ContextCaller): The context with additional information.
+        self: The instance of the original class.
+        inputs (Tensor): Inputs with shape (N, C, H, W).
+        data_samples (SampleList): The seg data samples.
+
+    Returns:
+        torch.Tensor: Output segmentation map of shape [N, 1, H, W].
+    """
+    batch_img_metas = []
+    for data_sample in data_samples:
+        batch_img_metas.append(data_sample.metainfo)
+    x = self.extract_feat(inputs)
+    out = self.decode_head[0].forward(x)
+    for i in range(1, self.num_stages - 1):
+        out = self.decode_head[i].forward(x, out)
+    seg_logit = self.decode_head[-1].predict(x, out, batch_img_metas,
+                                             self.test_cfg)
+    seg_pred = seg_logit.argmax(dim=1, keepdim=True)
+    return seg_pred
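Both predict rewrites end the same way: class scores are collapsed to per-pixel class ids. A self-contained sketch (19 classes, as in Cityscapes):

import torch

seg_logit = torch.rand(1, 19, 128, 256)            # N, C, H, W class scores
seg_pred = seg_logit.argmax(dim=1, keepdim=True)   # N, 1, H, W class ids
assert seg_pred.shape == (1, 1, 128, 256)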
@@ -1,28 +1,28 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch.nn.functional as F
-
 from mmdeploy.core import FUNCTION_REWRITER


 @FUNCTION_REWRITER.register_rewriter(
-    func_name='mmseg.models.segmentors.EncoderDecoder.simple_test')
-def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
-    """Rewrite `simple_test` for default backend.
+    func_name='mmseg.models.segmentors.EncoderDecoder.predict')
+def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs):
+    """Rewrite `predict` for default backend.

-    Support configured dynamic/static shape for model input and return
-    segmentation map as Tensor instead of numpy array.
+    1. only support mode=`whole` inference
+    2. skip calling self.postprocess_result

     Args:
         ctx (ContextCaller): The context with additional information.
         self: The instance of the original class.
-        img (Tensor | List[Tensor]): Input image tensor(s).
-        img_meta (dict): Dict containing image's meta information
-            such as `img_shape`.
+        inputs (Tensor): Inputs with shape (N, C, H, W).
+        data_samples (SampleList): The seg data samples.

     Returns:
         torch.Tensor: Output segmentation map of shape [N, 1, H, W].
     """
-    seg_logit = self.encode_decode(img, img_meta)
-    seg_logit = F.softmax(seg_logit, dim=1)
+    batch_img_metas = []
+    for data_sample in data_samples:
+        batch_img_metas.append(data_sample.metainfo)
+    x = self.extract_feat(inputs)
+    seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
     seg_pred = seg_logit.argmax(dim=1, keepdim=True)
     return seg_pred
@@ -28,7 +28,7 @@ def up_conv_block__forward(ctx, self, skip, x):
     # only valid when self.upsample is from build_upsample_layer
     if is_dynamic_shape(ctx.cfg) and not isinstance(self.upsample, ConvModule):
         # upsample with `size` instead of `scale_factor`
-        from mmseg.ops import Upsample
+        from mmseg.models.utils import Upsample
         for c in self.upsample.interp_upsample:
             if isinstance(c, Upsample):
                 c.size = skip.shape[-2:]
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .adaptive_pool import (adaptive_avg_pool2d__default,
+                            adaptive_avg_pool2d__ncnn)
 from .atan2 import atan2__default
 from .chunk import chunk__ncnn, chunk__torchscript
 from .expand import expand__ncnn
@@ -18,7 +20,8 @@ __all__ = [
     'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
     'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt',
     'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn',
-    'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn',
-    'chunk__torchscript', 'masked_fill__onnxruntime',
-    'tensor__setitem__default'
+    'triu__default', 'atan2__default', 'adaptive_avg_pool2d__default',
+    'normalize__ncnn', 'expand__ncnn', 'chunk__torchscript',
+    'masked_fill__onnxruntime', 'tensor__setitem__default',
+    'adaptive_avg_pool2d__ncnn'
 ]
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.nn.functional.adaptive_avg_pool2d')
+def adaptive_avg_pool2d__default(ctx, input, output_size):
+    """Rewrite `adaptive_avg_pool2d` for default backend."""
+    output_size = _pair(output_size)
+    if int(output_size[0]) == int(output_size[1]) == 1:
+        out = ctx.origin_func(input, output_size)
+    else:
+        deploy_cfg = ctx.cfg
+        is_dynamic_flag = is_dynamic_shape(deploy_cfg)
+        if is_dynamic_flag:
+            logger = get_root_logger()
+            logger.warning('`adaptive_avg_pool2d` would be '
+                           'replaced to `avg_pool2d` explicitly')
+        size = input.shape[2:]
+        k = [int(size[i] / output_size[i]) for i in range(0, len(size))]
+        out = F.avg_pool2d(
+            input,
+            kernel_size=k,
+            stride=k,
+            padding=0,
+            ceil_mode=False,
+            count_include_pad=False)
+    return out
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.nn.functional.adaptive_avg_pool2d',
+    backend=Backend.NCNN.value)
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.nn.functional.adaptive_avg_pool2d',
+    backend=Backend.TORCHSCRIPT.value)
+def adaptive_avg_pool2d__ncnn(ctx, input, output_size):
+    """Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend."""
+    return ctx.origin_func(input, output_size)
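The rewrite relies on adaptive average pooling reducing to a fixed `avg_pool2d` whenever the input size divides the output size evenly; a quick check:

import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 32, 32)
output_size = (4, 4)
k = [x.shape[2] // output_size[0], x.shape[3] // output_size[1]]

a = F.adaptive_avg_pool2d(x, output_size)
b = F.avg_pool2d(x, kernel_size=k, stride=k, padding=0, ceil_mode=False)
assert torch.allclose(a, b, atol=1e-6)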
@@ -1,8 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
-                                adaptive_avg_pool2d__default,
-                                adaptive_avg_pool2d__ncnn,
-                                adaptive_avg_pool3d__default)
+from .adaptive_pool import adaptive_avg_pool2d__ncnn
 from .gelu import gelu__ncnn
 from .grid_sampler import grid_sampler__default
 from .hardsigmoid import hardsigmoid__default
@@ -15,10 +12,8 @@ from .roll import roll_default
 from .squeeze import squeeze__default

 __all__ = [
-    'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
-    'adaptive_avg_pool3d__default', 'grid_sampler__default',
-    'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
-    'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
-    'layer_norm__ncnn', 'linear__ncnn', '_prepare_onnx_paddings__tensorrt',
-    'roll_default'
+    'grid_sampler__default', 'hardsigmoid__default', 'instance_norm__tensorrt',
+    'generic_rnn__ncnn', 'squeeze__default', 'adaptive_avg_pool2d__ncnn',
+    'gelu__ncnn', 'layer_norm__ncnn', 'linear__ncnn',
+    '_prepare_onnx_paddings__tensorrt', 'roll_default'
 ]
@@ -1,90 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-# Modified from:
-# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
-
-from torch.nn.modules.utils import _pair, _single, _triple
-from torch.onnx.symbolic_helper import parse_args
-
-from mmdeploy.core import SYMBOLIC_REWRITER
-
-
-def _adaptive_pool(name, type, tuple_fn, fn=None):
-    """Generic adaptive pooling."""
-
-    @parse_args('v', 'is')
-    def symbolic_fn(g, input, output_size):
-        if output_size == [1] * len(output_size) and type == 'AveragePool':
-            return g.op('GlobalAveragePool', input)
-        if not input.isCompleteTensor():
-            if output_size == [1] * len(output_size):
-                return g.op('GlobalMaxPool', input), None
-            raise NotImplementedError(
-                '[Adaptive pool]:input size not accessible')
-        dim = input.type().sizes()[2:]
-        if output_size == [1] * len(output_size) and type == 'MaxPool':
-            return g.op('GlobalMaxPool', input), None
-
-        # compute stride = floor(input_size / output_size)
-        s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
-
-        # compute kernel_size = input_size - (output_size - 1) * stride
-        k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]
-
-        # call max_poolxd_with_indices to get indices in the output
-        if type == 'MaxPool':
-            return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
-                      False)
-        output = g.op(
-            type,
-            input,
-            kernel_shape_i=tuple_fn(k),
-            strides_i=tuple_fn(s),
-            ceil_mode_i=False)
-        return output
-
-    return symbolic_fn
-
-
-adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
-                                     _single)
-adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
-                                     _pair)
-adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
-                                     _triple)
-
-
-@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
-def adaptive_avg_pool1d__default(ctx, *args):
-    """Register default symbolic function for `adaptive_avg_pool1d`.
-
-    Align symbolic of adaptive_pool between different torch version.
-    """
-    return adaptive_avg_pool1d(*args)
-
-
-@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
-def adaptive_avg_pool2d__default(ctx, *args):
-    """Register default symbolic function for `adaptive_avg_pool2d`.
-
-    Align symbolic of adaptive_pool between different torch version.
-    """
-    return adaptive_avg_pool2d(*args)
-
-
-@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
-def adaptive_avg_pool3d__default(ctx, *args):
-    """Register default symbolic function for `adaptive_avg_pool3d`.
-
-    Align symbolic of adaptive_pool between different torch version.
-    """
-    return adaptive_avg_pool3d(*args)
-
-
-@SYMBOLIC_REWRITER.register_symbolic(
-    'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn')
-def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size):
-    """Register ncnn symbolic function for `adaptive_avg_pool2d`.
-
-    Align symbolic of adaptive_avg_pool2d in ncnn.
-    """
-    return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size)
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmdeploy.core import SYMBOLIC_REWRITER
+
+
+@SYMBOLIC_REWRITER.register_symbolic(
+    'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn')
+def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size):
+    """Register ncnn symbolic function for `adaptive_avg_pool2d`.
+
+    Align symbolic of adaptive_avg_pool2d in ncnn.
+    """
+    return g.op('mmdeploy::AdaptiveAvgPool2d', x, output_size)
@@ -92,7 +92,8 @@ class TimeCounter:
                  warmup: int = 1,
                  log_interval: int = 1,
                  with_sync: bool = False,
-                 file: Optional[str] = None):
+                 file: Optional[str] = None,
+                 logger=None):
         """Activate the time counter.

         Args:
@@ -106,7 +107,8 @@ class TimeCounter:
                 is `None`.
         """
         assert warmup >= 1
-        logger = get_logger('test', log_file=file)
+        if logger is None:
+            logger = get_logger('test', log_file=file)
         cls.logger = logger
         if func_name is not None:
             warnings.warn('func_name must be globally unique if you call '
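The `logger=None` default lets callers inject an existing logger instead of creating a fresh 'test' logger per activation. A hedged usage sketch; the import path and the context-manager usage follow mmdeploy's profiler tooling and are assumptions here:

from mmengine.logging import MMLogger

from mmdeploy.utils.timer import TimeCounter  # assumed import path

logger = MMLogger.get_instance('benchmark')   # shared logger, created once
with TimeCounter.activate(
        warmup=10, log_interval=50, with_sync=False, logger=logger):
    pass  # timed inference would run here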
@@ -2,12 +2,9 @@ globals:
  codebase_dir: ../mmsegmentation
  checkpoint_force_download: False
  images:
-   img_leftImg8bit: &img_leftImg8bit ../mmsegmentation/tests/data/pseudo_cityscapes_dataset/leftImg8bit/frankfurt_000000_000294_leftImg8bit.png
+   img_leftImg8bit: &img_leftImg8bit ../mmsegmentation/tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png
-   img_loveda_0: &img_loveda_0 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/0.png
-   img_loveda_1: &img_loveda_1 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/1.png
-   img_loveda_2: &img_loveda_2 ../mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/2.png
-   img_potsdam: &img_potsdam ../mmsegmentation/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png
-   img_vaihingen: &img_vaihingen ../mmsegmentation/tests/data/pseudo_vaihingen_dataset/img_dir/area1_0_0_512_512.png

  metric_info: &metric_info
    mIoU: # named after metafile.Results.Metrics
      eval_name: mIoU # test.py --metrics args
@@ -25,13 +22,9 @@ globals:
 onnxruntime:
   pipeline_ort_static_fp32: &pipeline_ort_static_fp32
     convert_image: *convert_image
     backend_test: *default_backend_test
+    sdk_config: *sdk_dynamic
     deploy_config: configs/mmseg/segmentation_onnxruntime_static-1024x2048.py

-  pipeline_ort_static_fp32_512x512: &pipeline_ort_static_fp32_512x512
-    convert_image: *convert_image
-    backend_test: False
-    deploy_config: configs/mmseg/segmentation_onnxruntime_static-512x512.py
-
   pipeline_ort_dynamic_fp32: &pipeline_ort_dynamic_fp32
@@ -129,13 +122,11 @@ models:
  - name: FCN
    metafile: configs/fcn/fcn.yml
    model_configs:
-     - configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py
+     - configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ts_fp32
      - *pipeline_ort_dynamic_fp32
-     - *pipeline_trt_dynamic_fp32
+     # - *pipeline_trt_dynamic_fp32
      - *pipeline_trt_dynamic_fp16
-     - *pipeline_trt_dynamic_int8
+     # - *pipeline_trt_dynamic_int8
      - *pipeline_ncnn_static_fp32
      - *pipeline_pplnn_dynamic_fp32
      - *pipeline_openvino_dynamic_fp32
|
|||
- name: PSPNet
|
||||
metafile: configs/pspnet/pspnet.yml
|
||||
model_configs:
|
||||
- configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py
|
||||
- configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
- *pipeline_ts_fp32
|
||||
- *pipeline_ort_static_fp32
|
||||
|
@@ -155,7 +146,7 @@ models:
  - name: deeplabv3
    metafile: configs/deeplabv3/deeplabv3.yml
    model_configs:
-     - configs/deeplabv3/deeplabv3_r50-d8_512x1024_40k_cityscapes.py
+     - configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ts_fp32
      - *pipeline_ort_dynamic_fp32
@@ -167,7 +158,7 @@ models:
  - name: deeplabv3+
    metafile: configs/deeplabv3plus/deeplabv3plus.yml
    model_configs:
-     - configs/deeplabv3plus/deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py
+     - configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ts_fp32
      - *pipeline_ort_dynamic_fp32
@@ -179,7 +170,7 @@ models:
  - name: Fast-SCNN
    metafile: configs/fastscnn/fastscnn.yml
    model_configs:
-     - configs/fastscnn/fast_scnn_lr0.12_8x4_160k_cityscapes.py
+     - configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ts_fp32
      - *pipeline_ort_static_fp32
@@ -190,7 +181,7 @@ models:
  - name: UNet
    metafile: configs/unet/unet.yml
    model_configs:
-     - configs/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes.py
+     - configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ts_fp32
      - *pipeline_ort_dynamic_fp32
@@ -202,7 +193,7 @@ models:
  - name: ANN
    metafile: configs/ann/ann.yml
    model_configs:
-     - configs/ann/ann_r50-d8_512x1024_40k_cityscapes.py
+     - configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ort_static_fp32
      - *pipeline_trt_static_fp16
@@ -210,7 +201,7 @@ models:
  - name: APCNet
    metafile: configs/apcnet/apcnet.yml
    model_configs:
-     - configs/apcnet/apcnet_r50-d8_512x1024_40k_cityscapes.py
+     - configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
    pipelines:
      - *pipeline_ort_dynamic_fp32
      - *pipeline_trt_dynamic_fp16
@@ -219,7 +210,7 @@ models:
  - name: BiSeNetV1
    metafile: configs/bisenetv1/bisenetv1.yml
    model_configs:
-     - configs/bisenetv1/bisenetv1_r18-d32_4x4_1024x1024_160k_cityscapes.py
+     - configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py
    pipelines:
      - *pipeline_ort_dynamic_fp32
      - *pipeline_trt_dynamic_fp16
@@ -229,7 +220,7 @@ models:
  - name: BiSeNetV2
    metafile: configs/bisenetv2/bisenetv2.yml
    model_configs:
-     - configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
+     - configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py
    pipelines:
      - *pipeline_ort_dynamic_fp32
      - *pipeline_trt_dynamic_fp16
@@ -239,7 +230,7 @@ models:
   - name: CGNet
     metafile: configs/cgnet/cgnet.yml
     model_configs:
-      - configs/cgnet/cgnet_512x1024_60k_cityscapes.py
+      - configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -249,7 +240,7 @@ models:
   - name: EMANet
     metafile: configs/emanet/emanet.yml
     model_configs:
-      - configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py
+      - configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -258,7 +249,7 @@ models:
   - name: EncNet
     metafile: configs/encnet/encnet.yml
     model_configs:
-      - configs/encnet/encnet_r50-d8_512x1024_40k_cityscapes.py
+      - configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -267,7 +258,7 @@ models:
   - name: ERFNet
     metafile: configs/erfnet/erfnet.yml
     model_configs:
-      - configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py
+      - configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -277,7 +268,7 @@ models:
   - name: FastFCN
     metafile: configs/fastfcn/fastfcn.yml
     model_configs:
-      - configs/fastfcn/fastfcn_r50-d32_jpu_aspp_512x1024_80k_cityscapes.py
+      - configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -287,7 +278,7 @@ models:
   - name: GCNet
     metafile: configs/gcnet/gcnet.yml
     model_configs:
-      - configs/gcnet/gcnet_r50-d8_512x1024_40k_cityscapes.py
+      - configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -295,7 +286,7 @@ models:
   - name: ICNet
     metafile: configs/icnet/icnet.yml
     model_configs:
-      - configs/icnet/icnet_r18-d8_832x832_80k_cityscapes.py
+      - configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py
     pipelines:
       - *pipeline_ort_static_fp32
       - *pipeline_trt_static_fp16
@@ -304,26 +295,26 @@ models:
   - name: ISANet
     metafile: configs/isanet/isanet.yml
     model_configs:
-      - configs/isanet/isanet_r50-d8_512x1024_40k_cityscapes.py
+      - configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py
     pipelines:
-      - *pipeline_ort_dynamic_fp32
-      - *pipeline_trt_dynamic_fp16
-      - *pipeline_openvino_dynamic_fp32
+      - *pipeline_ort_static_fp32
+      - *pipeline_trt_static_fp16
+      - *pipeline_openvino_static_fp32
 
   - name: OCRNet
     metafile: configs/ocrnet/ocrnet.yml
     model_configs:
-      - configs/ocrnet/ocrnet_hr18s_512x1024_40k_cityscapes.py
+      - configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
-      - *pipeline_trt_dynamic_fp32
+      - *pipeline_trt_dynamic_fp16
       - *pipeline_ncnn_static_fp32
       - *pipeline_openvino_dynamic_fp32
 
   - name: PointRend
     metafile: configs/point_rend/point_rend.yml
     model_configs:
-      - configs/point_rend/pointrend_r50_512x1024_80k_cityscapes.py
+      - configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -332,18 +323,18 @@ models:
   - name: Semantic FPN
     metafile: configs/sem_fpn/sem_fpn.yml
     model_configs:
-      - configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py
+      - configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
-      - *pipeline_trt_dynamic_fp32
+      - *pipeline_trt_dynamic_fp16
       - *pipeline_ncnn_static_fp32
       - *pipeline_openvino_dynamic_fp32
 
   - name: STDC
     metafile: configs/stdc/stdc.yml
     model_configs:
-      - configs/stdc/stdc1_in1k-pre_512x1024_80k_cityscapes.py
-      - configs/stdc/stdc2_in1k-pre_512x1024_80k_cityscapes.py
+      - configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py
+      - configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp16
@@ -353,15 +344,15 @@ models:
   - name: UPerNet
     metafile: configs/upernet/upernet.yml
     model_configs:
-      - configs/upernet/upernet_r50_512x1024_40k_cityscapes.py
+      - configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py
     pipelines:
       - *pipeline_ort_static_fp32
       - *pipeline_trt_static_fp16
   - name: Segmenter
     metafile: configs/segmenter/segmenter.yml
     model_configs:
-      - configs/segmenter/segmenter_vit-s_mask_8x1_512x512_160k_ade20k.py
-      - configs/segmenter/segmenter_vit-s_linear_8x1_512x512_160k_ade20k.py
+      - configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
+      - configs/segmenter/segmenter_vit-s_mask_8xb1-160k_ade20k-512x512.py
     pipelines:
       - *pipeline_ort_static_fp32_512x512
      - *pipeline_trt_static_fp32_512x512
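Note: the hunks above only rename the mmseg 0.x config files to their dev-1.x equivalents; the *pipeline_... anchors are unchanged references defined earlier in the same regression yml. As a quick orientation, a hedged sketch of how one of these entries would be exercised, assuming mmdeploy's tools/regression_test.py entry point and its --codebase/--backends flags:

    # Hypothetical invocation; flag names are assumptions, not confirmed by this diff.
    python tools/regression_test.py --codebase mmseg --backends onnxruntime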
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .utils import (generate_datasample, generate_mmseg_deploy_config,
+                    generate_mmseg_task_processor)
+
+__all__ = [
+    'generate_datasample', 'generate_mmseg_deploy_config',
+    'generate_mmseg_task_processor'
+]
@@ -2,44 +2,44 @@
 # dataset settings
 dataset_type = 'CityscapesDataset'
 data_root = '.'
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 crop_size = (128, 128)
 
 test_pipeline = [
     dict(type='LoadImageFromFile'),
-    dict(
-        type='MultiScaleFlipAug',
-        img_scale=(128, 128),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=True),
-            dict(type='RandomFlip'),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
+    dict(type='Resize', scale=crop_size, keep_ratio=False),
+    # add loading annotation after ``Resize`` because ground truth
+    # does not need to do resize data transform
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='PackSegInputs')
 ]
-data = dict(
-    samples_per_gpu=1,
-    workers_per_gpu=1,
-    val=dict(
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        img_dir='',
-        ann_dir='',
-        pipeline=test_pipeline),
-    test=dict(
-        type=dataset_type,
-        data_root=data_root,
-        img_dir='',
-        ann_dir='',
+        data_prefix=dict(img_path='', seg_map_path=''),
         pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
+test_evaluator = val_evaluator
 
 # model settings
 norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
+data_preprocessor = dict(
+    type='SegDataPreProcessor',
+    mean=[123.675, 116.28, 103.53],
+    std=[58.395, 57.12, 57.375],
+    bgr_to_rgb=True,
+    pad_val=0,
+    seg_pad_val=255)
+
 model = dict(
     type='EncoderDecoder',
+    data_preprocessor=data_preprocessor,
     backbone=dict(
         type='FastSCNN',
         downsample_dw_channels=(32, 48),
@@ -64,6 +64,33 @@ model = dict(
         align_corners=False,
         loss_decode=dict(
             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)),
+
     # model training and testing settings
     train_cfg=dict(),
     test_cfg=dict(mode='whole'))
+
+# from default_runtime
+default_scope = 'mmseg'
+env_cfg = dict(
+    cudnn_benchmark=True,
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    dist_cfg=dict(backend='nccl'),
+)
+log_level = 'INFO'
+load_from = None
+resume = False
+
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+
+# from schedules
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+    timer=dict(type='IterTimerHook'),
+    logger=dict(type='LoggerHook', interval=50),
+    param_scheduler=dict(type='ParamSchedulerHook'),
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
+    sampler_seed=dict(type='DistSamplerSeedHook'),
+)
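The two hunks above port the shared test config from the 0.x data/img_norm_cfg layout to the mmengine-style val_dataloader/val_evaluator/data_preprocessor layout. A minimal sanity-check sketch (mmengine's standard Config loader; the path is the one used by the tests in this PR):

    from mmengine import Config

    cfg = Config.fromfile('tests/test_codebase/test_mmseg/data/model.py')
    assert cfg.test_dataloader == cfg.val_dataloader  # test split reuses val
    print(cfg.val_evaluator)  # dict(type='IoUMetric', iou_metrics=['mIoU'])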
@@ -1,12 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import mmcv
-import numpy as np
+import mmengine
 import pytest
 import torch
-import torch.nn as nn
-from mmcv import ConfigDict
-from mmseg.models import BACKBONES, HEADS
-from mmseg.models.decode_heads.decode_head import BaseDecodeHead
 
 from mmdeploy.codebase import import_codebase
 from mmdeploy.utils import Backend, Codebase, Task
@@ -15,225 +10,52 @@ from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
 
 import_codebase(Codebase.MMSEG)
 
-
-@BACKBONES.register_module()
-class ExampleBackbone(nn.Module):
-
-    def __init__(self):
-        super(ExampleBackbone, self).__init__()
-        self.conv = nn.Conv2d(3, 3, 3)
-
-    def init_weights(self, pretrained=None):
-        pass
-
-    def forward(self, x):
-        return [self.conv(x)]
+from .utils import generate_datasample  # noqa: E402
+from .utils import generate_mmseg_deploy_config  # noqa: E402
+from .utils import generate_mmseg_task_processor  # noqa: E402
 
 
-@HEADS.register_module()
-class ExampleDecodeHead(BaseDecodeHead):
-
-    def __init__(self):
-        super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19)
-
-    def forward(self, inputs):
-        return self.cls_seg(inputs[0])
-
-
-def get_model(type='EncoderDecoder',
-              backbone='ExampleBackbone',
-              decode_head='ExampleDecodeHead'):
-
-    from mmseg.models import build_segmentor
-
-    cfg = ConfigDict(
-        type=type,
-        backbone=dict(type=backbone),
-        decode_head=dict(type=decode_head),
-        train_cfg=None,
-        test_cfg=dict(mode='whole'))
-    segmentor = build_segmentor(cfg)
-
-    return segmentor
-
-
-def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
-    """Create a superset of inputs needed to run test or train batches.
-
-    Args:
-        input_shape (tuple):
-            input batch dimensions
-        num_classes (int):
-            number of semantic classes
-    """
-    (N, C, H, W) = input_shape
-
-    rng = np.random.RandomState(0)
-
-    imgs = rng.rand(*input_shape)
-    segs = rng.randint(
-        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
-
-    img_metas = [{
-        'img_shape': (H, W, C),
-        'ori_shape': (H, W, C),
-        'pad_shape': (H, W, C),
-        'filename': '<demo>.png',
-        'scale_factor': 1.0,
-        'flip': False,
-        'flip_direction': 'horizontal'
-    } for _ in range(N)]
-
-    mm_inputs = {
-        'imgs': torch.FloatTensor(imgs),
-        'img_metas': img_metas,
-        'gt_semantic_seg': torch.LongTensor(segs)
-    }
-    return mm_inputs
-
-
-@pytest.mark.parametrize('backend',
-                         [Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
-def test_encoderdecoder_simple_test(backend):
+@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
+def test_encoderdecoder_predict(backend):
     check_backend(backend)
-    segmentor = get_model()
-    segmentor.cpu().eval()
-
-    deploy_cfg = mmcv.Config(
-        dict(
-            backend_config=dict(type=backend.value),
-            onnx_config=dict(output_names=['result'], input_shape=None),
-            codebase_config=dict(type='mmseg', task='Segmentation')))
-
-    if isinstance(segmentor.decode_head, nn.ModuleList):
-        num_classes = segmentor.decode_head[-1].num_classes
-    else:
-        num_classes = segmentor.decode_head.num_classes
-    mm_inputs = _demo_mm_inputs(
-        input_shape=(1, 3, 32, 32), num_classes=num_classes)
-    imgs = mm_inputs.pop('imgs')
-    img_metas = mm_inputs.pop('img_metas')
-    model_inputs = {'img': imgs, 'img_meta': img_metas}
-    model_outputs = get_model_outputs(segmentor, 'simple_test', model_inputs)
-    img_meta = {
-        'img_shape':
-        (img_metas[0]['img_shape'][0], img_metas[0]['img_shape'][1])
-    }
-    wrapped_model = WrapModel(segmentor, 'simple_test', img_meta=img_meta)
-
+    deploy_cfg = generate_mmseg_deploy_config(backend.value)
+    task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
+    segmentor = task_processor.build_pytorch_model()
+    size = 256
+    inputs = torch.randn(1, 3, size, size)
+    data_samples = [generate_datasample(size, size)]
+    wrapped_model = WrapModel(segmentor, 'predict', data_samples=data_samples)
+    model_outputs = wrapped_model(inputs)[0].pred_sem_seg.data
     rewrite_inputs = {
-        'img': imgs,
+        'inputs': inputs,
     }
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+    rewrite_outputs, _ = get_rewrite_outputs(
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
         deploy_cfg=deploy_cfg)
-
-    model_outputs = torch.tensor(model_outputs[0])
-    rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
-        model_outputs.shape)
-    assert torch.allclose(rewrite_outputs, model_outputs)
+    assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
 
 
-@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
+@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
 def test_basesegmentor_forward(backend):
     check_backend(backend)
-    segmentor = get_model()
-    segmentor.cpu().eval()
-
-    deploy_cfg = mmcv.Config(
-        dict(
-            backend_config=dict(type=backend.value),
-            onnx_config=dict(output_names=['result'], input_shape=None),
-            codebase_config=dict(type='mmseg', task='Segmentation')))
-
-    if isinstance(segmentor.decode_head, nn.ModuleList):
-        num_classes = segmentor.decode_head[-1].num_classes
-    else:
-        num_classes = segmentor.decode_head.num_classes
-    mm_inputs = _demo_mm_inputs(num_classes=num_classes)
-    imgs = mm_inputs.pop('imgs')
-    img_metas = mm_inputs.pop('img_metas')
-    model_inputs = {
-        'img': [imgs],
-        'img_metas': [img_metas],
-        'return_loss': False
+    deploy_cfg = generate_mmseg_deploy_config(backend.value)
+    task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
+    segmentor = task_processor.build_pytorch_model()
+    size = 256
+    inputs = torch.randn(1, 3, size, size)
+    data_samples = [generate_datasample(size, size)]
+    wrapped_model = WrapModel(
+        segmentor, 'forward', data_samples=data_samples, mode='predict')
+    model_outputs = wrapped_model(inputs)[0].pred_sem_seg.data
+    rewrite_inputs = {
+        'inputs': inputs,
     }
-    model_outputs = get_model_outputs(segmentor, 'forward', model_inputs)
-
-    wrapped_model = WrapModel(segmentor, 'forward')
-    rewrite_inputs = {'img': imgs}
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+    rewrite_outputs, _ = get_rewrite_outputs(
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
-
-    model_outputs = torch.tensor(model_outputs[0])
-    rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
-        model_outputs.shape)
-    assert torch.allclose(rewrite_outputs, model_outputs)
-
-
-@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
-def test_aspphead_forward(backend):
-    check_backend(backend)
-    from mmseg.models.decode_heads import ASPPHead
-    head = ASPPHead(
-        in_channels=32, channels=16, num_classes=19,
-        dilations=(1, 12, 24)).eval()
-
-    deploy_cfg = mmcv.Config(
-        dict(
-            backend_config=dict(type=backend.value),
-            onnx_config=dict(
-                output_names=['result'], input_shape=(1, 32, 45, 45)),
-            codebase_config=dict(type='mmseg', task='Segmentation')))
-    inputs = [torch.randn(1, 32, 45, 45)]
-    model_inputs = {'inputs': inputs}
-    with torch.no_grad():
-        model_outputs = get_model_outputs(head, 'forward', model_inputs)
-    wrapped_model = WrapModel(head, 'forward')
-    rewrite_inputs = {'inputs': inputs}
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
-        wrapped_model=wrapped_model,
-        model_inputs=rewrite_inputs,
-        deploy_cfg=deploy_cfg)
-    if is_backend_output:
-        rewrite_outputs = rewrite_outputs[0]
-    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
-        model_outputs.shape)
-    assert torch.allclose(
-        rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
-
-
-@pytest.mark.parametrize('backend',
-                         [Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
-def test_psphead_forward(backend):
-    check_backend(backend)
-    from mmseg.models.decode_heads import PSPHead
-    head = PSPHead(in_channels=32, channels=16, num_classes=19).eval()
-
-    deploy_cfg = mmcv.Config(
-        dict(
-            backend_config=dict(type=backend.value),
-            onnx_config=dict(output_names=['result'], input_shape=None),
-            codebase_config=dict(type='mmseg', task='Segmentation')))
-    inputs = [torch.randn(1, 32, 45, 45)]
-    model_inputs = {'inputs': inputs}
-    with torch.no_grad():
-        model_outputs = get_model_outputs(head, 'forward', model_inputs)
-    wrapped_model = WrapModel(head, 'forward')
-    rewrite_inputs = {'inputs': inputs}
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
-        wrapped_model=wrapped_model,
-        model_inputs=rewrite_inputs,
-        deploy_cfg=deploy_cfg)
-    if is_backend_output:
-        rewrite_outputs = rewrite_outputs[0]
-    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
-        model_outputs.shape)
-    assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
+    assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
 
 
 @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
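All of the rewritten tests above share one pattern: WrapModel binds fixed keyword arguments (data_samples, mode) to a model method so the wrapped callable takes only the input tensor, and get_rewrite_outputs runs that same callable through the deploy rewriters. A condensed sketch using only names that appear in this diff:

    # Sketch of the comparison pattern; variables as defined in the tests above.
    wrapped = WrapModel(segmentor, 'predict', data_samples=data_samples)
    torch_out = wrapped(inputs)[0].pred_sem_seg.data  # eager PyTorch result
    backend_out, _ = get_rewrite_outputs(
        wrapped_model=wrapped,
        model_inputs={'inputs': inputs},
        deploy_cfg=deploy_cfg)  # same call traced through the rewriters
    assert torch.allclose(torch_out, backend_out[0].squeeze(0))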
@@ -242,7 +64,7 @@ def test_emamodule_forward(backend):
     from mmseg.models.decode_heads.ema_head import EMAModule
     head = EMAModule(8, 2, 2, 1.0).eval()
 
-    deploy_cfg = mmcv.Config(
+    deploy_cfg = mmengine.Config(
         dict(
             backend_config=dict(type=backend.value),
             onnx_config=dict(
@@ -290,7 +112,7 @@ def test_upconvblock_forward(backend, is_dynamic_shape):
                 3: 'w'
             },
         } if is_dynamic_shape else None
-    deploy_cfg = mmcv.Config(
+    deploy_cfg = mmengine.Config(
         dict(
             backend_config=dict(type=backend.value),
             onnx_config=dict(
@@ -1,42 +1,32 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
 import os
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Any
 
 import mmcv
 import numpy as np
 import pytest
 import torch
-from torch.utils.data import DataLoader
 
 import mmdeploy.backend.onnxruntime as ort_apis
 from mmdeploy.apis import build_task_processor
 from mmdeploy.codebase import import_codebase
 from mmdeploy.utils import Codebase, load_config
-from mmdeploy.utils.test import DummyModel, SwitchBackendWrapper
+from mmdeploy.utils.test import SwitchBackendWrapper
 
 import_codebase(Codebase.MMSEG)
 
+from .utils import generate_datasample  # noqa: E402
+from .utils import generate_mmseg_deploy_config  # noqa: E402
+
 model_cfg_path = 'tests/test_codebase/test_mmseg/data/model.py'
 model_cfg = load_config(model_cfg_path)[0]
-deploy_cfg = mmcv.Config(
-    dict(
-        backend_config=dict(type='onnxruntime'),
-        codebase_config=dict(type='mmseg', task='Segmentation'),
-        onnx_config=dict(
-            type='onnx',
-            export_params=True,
-            keep_initializers_as_inputs=False,
-            opset_version=11,
-            input_shape=None,
-            input_names=['input'],
-            output_names=['output'])))
+deploy_cfg = generate_mmseg_deploy_config()
 
 onnx_file = NamedTemporaryFile(suffix='.onnx').name
 task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
 img_shape = (32, 32)
-img = np.random.rand(*img_shape, 3)
+tiger_img_path = 'tests/data/tiger.jpeg'
+img = mmcv.imread(tiger_img_path)
+img = mmcv.imresize(img, img_shape)
 
 
 @pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
@@ -88,29 +78,31 @@ def test_build_backend_model(backend_model):
 
 
 def test_create_input():
-    inputs = task_processor.create_input(img, input_shape=img_shape)
+    img_path = 'tests/data/tiger.jpeg'
+    data_preprocessor = task_processor.build_data_preprocessor()
+    inputs = task_processor.create_input(
+        img_path, input_shape=img_shape, data_preprocessor=data_preprocessor)
     assert isinstance(inputs, tuple) and len(inputs) == 2
 
 
-def test_run_inference(backend_model):
-    input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
-    results = task_processor.run_inference(backend_model, input_dict)
-    assert results is not None
+def test_build_data_preprocessor():
+    from mmseg.models import SegDataPreProcessor
+    data_preprocessor = task_processor.build_data_preprocessor()
+    assert isinstance(data_preprocessor, SegDataPreProcessor)
 
 
-def test_visualize(backend_model):
-    input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
-    results = task_processor.run_inference(backend_model, input_dict)
-    with TemporaryDirectory() as dir:
-        filename = dir + 'tmp.jpg'
-        task_processor.visualize(backend_model, img, results[0], filename, '')
-        assert os.path.exists(filename)
+def test_get_visualizer():
+    from mmseg.visualization import SegLocalVisualizer
+    tmp_dir = TemporaryDirectory().name
+    visualizer = task_processor.get_visualizer('ort', tmp_dir)
+    assert isinstance(visualizer, SegLocalVisualizer)
 
 
 def test_get_tensort_from_input():
-    input_data = {'img': [torch.ones(3, 4, 5)]}
+    data = torch.rand(3, 4, 5)
+    input_data = {'inputs': data}
     inputs = task_processor.get_tensor_from_input(input_data)
-    assert torch.equal(inputs, torch.ones(3, 4, 5))
+    assert torch.equal(inputs, data)
 
 
 def test_get_partition_cfg():
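The create_input flow now threads an explicit data preprocessor through the task processor. A condensed sketch built from the helpers added in this PR (generate_mmseg_task_processor comes from the new tests/test_codebase/test_mmseg/utils.py):

    processor = generate_mmseg_task_processor()
    preprocessor = processor.build_data_preprocessor()
    inputs, tensor = processor.create_input(
        'tests/data/tiger.jpeg', input_shape=(32, 32),
        data_preprocessor=preprocessor)  # returns a 2-tuple, as asserted above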
@@ -122,24 +114,39 @@ def test_get_partition_cfg():
 
 def test_build_dataset_and_dataloader():
+    from torch.utils.data import DataLoader, Dataset
+    val_dataloader = model_cfg['val_dataloader']
     dataset = task_processor.build_dataset(
-        dataset_cfg=model_cfg, dataset_type='test')
+        dataset_cfg=val_dataloader['dataset'])
     assert isinstance(dataset, Dataset), 'Failed to build dataset'
-    dataloader = task_processor.build_dataloader(dataset, 1, 1)
+    dataloader = task_processor.build_dataloader(val_dataloader)
     assert isinstance(dataloader, DataLoader), 'Failed to build dataloader'
 
 
-def test_single_gpu_test_and_evaluate():
-    from mmcv.parallel import MMDataParallel
-
-    # Prepare dataloader
-    dataloader = DataLoader([])
-
-    # Prepare dummy model
-    model = DummyModel(outputs=[torch.rand([1, 1, *img_shape])])
-    model = MMDataParallel(model, device_ids=[0])
-    assert model is not None
-    # Run test
-    outputs = task_processor.single_gpu_test(model, dataloader)
-    assert outputs is not None
-    task_processor.evaluate_outputs(model_cfg, outputs, [])
+def test_build_test_runner(backend_model):
+    from mmdeploy.codebase.base.runner import DeployTestRunner
+    temp_dir = TemporaryDirectory().name
+    runner = task_processor.build_test_runner(backend_model, temp_dir)
+    assert isinstance(runner, DeployTestRunner)
+
+
+def test_visualize():
+    h, w = img.shape[:2]
+    datasample = generate_datasample(h, w)
+    output_file = NamedTemporaryFile(suffix='.jpg').name
+    task_processor.visualize(
+        img, datasample, output_file, show_result=False, window_name='test')
+
+
+def test_get_preprocess():
+    process = task_processor.get_preprocess()
+    assert process is not None
+
+
+def test_get_postprocess():
+    process = task_processor.get_postprocess()
+    assert isinstance(process, dict)
+
+
+def test_get_model_name():
+    name = task_processor.get_model_name()
+    assert isinstance(name, str)
@@ -1,10 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-from tempfile import NamedTemporaryFile
-
-import mmcv
 import numpy as np
-import pytest
+import mmengine
 import torch
 
 import mmdeploy.backend.onnxruntime as ort_apis
@@ -14,6 +9,9 @@ from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
 
 import_codebase(Codebase.MMSEG)
 
+from .utils import generate_datasample  # noqa: E402
+from .utils import generate_mmseg_deploy_config  # noqa: E402
+
 NUM_CLASS = 19
 IMAGE_SIZE = 32
 
@@ -30,101 +28,34 @@ class TestEnd2EndModel:
         # simplify backend inference
         cls.wrapper = SwitchBackendWrapper(ORTWrapper)
         cls.outputs = {
-            'outputs': torch.rand(1, 1, IMAGE_SIZE, IMAGE_SIZE),
+            'output': torch.rand(1, 1, IMAGE_SIZE, IMAGE_SIZE),
         }
         cls.wrapper.set(outputs=cls.outputs)
-        deploy_cfg = mmcv.Config(
-            {'onnx_config': {
-                'output_names': ['outputs']
-            }})
+        deploy_cfg = generate_mmseg_deploy_config()
 
         from mmdeploy.codebase.mmseg.deploy.segmentation_model import \
             End2EndModel
-        class_names = ['' for i in range(NUM_CLASS)]
-        palette = np.random.randint(0, 255, size=(NUM_CLASS, 3))
         cls.end2end_model = End2EndModel(
-            Backend.ONNXRUNTIME, [''],
-            device='cpu',
-            class_names=class_names,
-            palette=palette,
-            deploy_cfg=deploy_cfg)
+            Backend.ONNXRUNTIME, [''], device='cpu', deploy_cfg=deploy_cfg)
 
     @classmethod
     def teardown_class(cls):
         cls.wrapper.recover()
 
-    @pytest.mark.parametrize(
-        'ori_shape',
-        [[IMAGE_SIZE, IMAGE_SIZE, 3], [2 * IMAGE_SIZE, 2 * IMAGE_SIZE, 3]])
-    def test_forward(self, ori_shape):
-        imgs = [torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)]
-        img_metas = [[{
-            'ori_shape': ori_shape,
-            'img_shape': [IMAGE_SIZE, IMAGE_SIZE, 3],
-            'scale_factor': [1., 1., 1., 1.],
-        }]]
-        results = self.end2end_model.forward(imgs, img_metas)
-        assert results is not None, 'failed to get output using '\
-            'End2EndModel'
-
-    def test_forward_test(self):
-        imgs = torch.rand(2, 3, IMAGE_SIZE, IMAGE_SIZE)
-        results = self.end2end_model.forward_test(imgs)
-        assert isinstance(results[0], np.ndarray)
-
-    def test_show_result(self):
-        input_img = np.zeros([IMAGE_SIZE, IMAGE_SIZE, 3])
-        img_path = NamedTemporaryFile(suffix='.jpg').name
-
-        result = [torch.rand(IMAGE_SIZE, IMAGE_SIZE)]
-        self.end2end_model.show_result(
-            input_img, result, '', show=False, out_file=img_path)
-        assert osp.exists(img_path), 'Fails to create drawn image.'
-
-
-@pytest.mark.parametrize('from_file', [True, False])
-@pytest.mark.parametrize('data_type', ['train', 'val', 'test'])
-def test_get_classes_palette_from_config(from_file, data_type):
-    from mmseg.datasets import DATASETS
-
-    from mmdeploy.codebase.mmseg.deploy.segmentation_model import \
-        get_classes_palette_from_config
-    dataset_type = 'CityscapesDataset'
-    data_cfg = mmcv.Config({
-        'data': {
-            data_type:
-            dict(
-                type=dataset_type,
-                data_root='',
-                img_dir='',
-                ann_dir='',
-                pipeline=None)
-        }
-    })
-
-    if from_file:
-        config_path = NamedTemporaryFile(suffix='.py').name
-        with open(config_path, 'w') as file:
-            file.write(data_cfg.pretty_text)
-        data_cfg = config_path
-
-    classes, palette = get_classes_palette_from_config(data_cfg)
-    module = DATASETS.module_dict[dataset_type]
-    assert classes == module.CLASSES, \
-        f'fail to get CLASSES of dataset: {dataset_type}'
-    assert palette == module.PALETTE, \
-        f'fail to get PALETTE of dataset: {dataset_type}'
+    def test_forward(self):
+        from mmseg.structures import SegDataSample
+        imgs = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE)
+        data_samples = [generate_datasample(IMAGE_SIZE, IMAGE_SIZE)]
+        results = self.end2end_model.forward(imgs, data_samples)
+        assert len(results) == 1
+        assert isinstance(results[0], SegDataSample)
 
 
 @backend_checker(Backend.ONNXRUNTIME)
 def test_build_segmentation_model():
-    model_cfg = mmcv.Config(
+    model_cfg = mmengine.Config(
         dict(data=dict(test={'type': 'CityscapesDataset'})))
-    deploy_cfg = mmcv.Config(
-        dict(
-            backend_config=dict(type='onnxruntime'),
-            onnx_config=dict(output_names=['outputs']),
-            codebase_config=dict(type='mmseg')))
+    deploy_cfg = generate_mmseg_deploy_config()
 
     from mmdeploy.backend.onnxruntime import ORTWrapper
     ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmengine
+import torch
+from mmengine.structures import PixelData
+
+from mmdeploy.apis import build_task_processor
+from mmdeploy.utils import load_config
+
+
+def generate_datasample(h, w):
+    from mmseg.structures import SegDataSample
+    metainfo = dict(img_shape=(h, w), ori_shape=(h, w), pad_shape=(h, w))
+    data_sample = SegDataSample()
+    data_sample.set_metainfo(metainfo)
+    seg_pred = torch.randint(0, 2, (1, h, w))
+    seg_gt = torch.randint(0, 2, (1, h, w))
+    data_sample.set_data(dict(pred_sem_seg=PixelData(**dict(data=seg_pred))))
+    data_sample.set_data(
+        dict(gt_sem_seg=PixelData(**dict(data=seg_gt, metainfo=metainfo))))
+    return data_sample
+
+
+def generate_mmseg_deploy_config(backend='onnxruntime'):
+    deploy_cfg = mmengine.Config(
+        dict(
+            backend_config=dict(type=backend),
+            codebase_config=dict(type='mmseg', task='Segmentation'),
+            onnx_config=dict(
+                type='onnx',
+                export_params=True,
+                keep_initializers_as_inputs=False,
+                opset_version=11,
+                input_shape=None,
+                input_names=['input'],
+                output_names=['output'])))
+    return deploy_cfg
+
+
+def generate_mmseg_task_processor(model_cfg=None, deploy_cfg=None):
+    if model_cfg is None:
+        model_cfg = 'tests/test_codebase/test_mmseg/data/model.py'
+    if deploy_cfg is None:
+        deploy_cfg = generate_mmseg_deploy_config()
+    model_cfg, deploy_cfg = load_config(model_cfg, deploy_cfg)
+    task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
+    return task_processor
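A short usage sketch of the three helpers defined above, as the rewritten tests consume them:

    processor = generate_mmseg_task_processor()  # default model/deploy configs
    segmentor = processor.build_pytorch_model()
    sample = generate_datasample(256, 256)       # random pred/gt masks + metainfo
    assert sample.pred_sem_seg.data.shape == (1, 256, 256)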
@@ -447,10 +447,14 @@ def get_info_from_log_file(info_type: str, log_path: Path,
         line_index = -1
     else:
         line_index = -2
 
-    if yaml_metric_key in ['accuracy_top-1', 'mIoU', 'Eval-PSNR']:
+    if yaml_metric_key == 'mIoU':
+        metric_line = lines[-1]
+        info_value = metric_line.split('mIoU: ')[1].split(' ')[0]
+        info_value = float(info_value)
+        return info_value
+    elif yaml_metric_key in ['accuracy_top-1', 'Eval-PSNR']:
         # info in last second line
-        # mmcls, mmseg, mmedit
+        # mmcls, mmedit
         metric_line = lines[line_index - 1]
     elif yaml_metric_key == 'AP':
         # info in last tenth line
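The new mIoU branch assumes the metric sits on the last log line in the 'mIoU: <value>' form that mmseg's IoUMetric prints. A worked example of the string handling (the log line itself is illustrative, not taken from a real run):

    line = 'Iter(test) [500/500]  aAcc: 96.05  mIoU: 71.23  mAcc: 78.90'
    value = float(line.split('mIoU: ')[1].split(' ')[0])
    assert value == 71.23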
@@ -655,10 +659,6 @@ def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
               f'--device {device_type} '
 
-    codebase_name = get_codebase(str(deploy_cfg_path)).value
-    if codebase_name != 'mmedit' and codebase_name != 'mmdet':
-        # mmedit and mmdet dont --metric
-        cmd_str += f'--metrics {eval_name} '
-
     logger.info(f'Process cmd = {cmd_str}')
     # Test backend
     shell_res = subprocess.run(
@@ -937,7 +937,10 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
     # Test the model
     if convert_result and test_type == 'precision':
         # Get evaluation metric from model config
-        metrics_eval_list = model_cfg.test_evaluator.get('metric', [])
+        if codebase_name == 'mmseg':
+            metrics_eval_list = model_cfg.val_evaluator.iou_metrics
+        else:
+            metrics_eval_list = model_cfg.test_evaluator.get('metric', [])
         if isinstance(metrics_eval_list, str):
             # some config is using str only
             metrics_eval_list = [metrics_eval_list]
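The branch is needed because mmseg 1.x configs expose metrics as val_evaluator.iou_metrics while other codebases keep a test_evaluator.metric entry. A hedged sketch with assumed minimal configs:

    from mmengine import Config

    # Assumed minimal configs for illustration only.
    seg_cfg = Config(dict(val_evaluator=dict(type='IoUMetric', iou_metrics=['mIoU'])))
    det_cfg = Config(dict(test_evaluator=dict(type='CocoMetric', metric='bbox')))

    assert seg_cfg.val_evaluator.iou_metrics == ['mIoU']
    assert det_cfg.test_evaluator.get('metric', []) == 'bbox'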