mmdeploy/mmdeploy/codebase/mmdet/deploy/object_detection.py

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine import Config
from mmengine.dataset import pseudo_collate
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task
from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape

MMDET_TASK = Registry('mmdet_tasks')


@CODEBASE.register_module(Codebase.MMDET.value)
class MMDetection(MMCodebase):
"""MMDetection codebase class."""
task_registry = MMDET_TASK
@classmethod
def register_all_modules(cls):
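        """Register all mmdet modules into the mmengine registries."""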
from mmdet.utils.setup_env import register_all_modules
register_all_modules(True)


def process_model_config(model_cfg: Config,
                         imgs: Union[Sequence[str], Sequence[np.ndarray]],
                         input_shape: Optional[Sequence[int]] = None):
"""Process the model config.
Args:
model_cfg (Config): The model config.
imgs (Sequence[str] | Sequence[np.ndarray]): Input image(s), accepted
data type are List[str], List[np.ndarray].
input_shape (list[int]): A list of two integer in (width, height)
format specifying input shape. Default: None.
Returns:
Config: the model config after processing.
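
    Example:
        >>> # Illustrative sketch only; assumes a minimal config whose
        >>> # ``test_pipeline`` contains a ``Resize`` transform.
        >>> from mmengine import Config
        >>> cfg = Config(
        ...     dict(test_pipeline=[
        ...         dict(type='LoadImageFromFile'),
        ...         dict(type='Resize', scale=(1333, 800), keep_ratio=True),
        ...         dict(type='PackDetInputs')
        ...     ]))
        >>> cfg = process_model_config(cfg, ['demo.jpg'], input_shape=[640, 640])
        >>> cfg.test_pipeline[1].scale
        (640, 640)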
"""
cfg = model_cfg.copy()
if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy()
# set loading pipeline type
cfg.test_pipeline[0].type = 'LoadImageFromNDArray'
pipeline = cfg.test_pipeline
for i, transform in enumerate(pipeline):
# for static exporting
if input_shape is not None and transform.type == 'Resize':
pipeline[i].keep_ratio = False
pipeline[i].scale = tuple(input_shape)
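    # Annotations are not needed at inference time, so drop LoadAnnotations.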
pipeline = [
transform for transform in pipeline
if transform.type != 'LoadAnnotations'
]
cfg.test_pipeline = pipeline
return cfg


def _get_dataset_metainfo(model_cfg: Config):
    """Get metainfo of dataset.

    Args:
        model_cfg (Config): Input model Config object.

    Returns:
        dict | None: Meta information of the dataset (e.g. class names), or
            None if it cannot be resolved from the config.
    """
from mmdet import datasets # noqa
from mmdet.registry import DATASETS
module_dict = DATASETS.module_dict
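    # Look up the dataset class from the test/val/train dataloader configs in
    # that order and return the first metainfo that can be resolved.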
for dataloader_name in [
'test_dataloader', 'val_dataloader', 'train_dataloader'
]:
if dataloader_name not in model_cfg:
continue
dataloader_cfg = model_cfg[dataloader_name]
dataset_cfg = dataloader_cfg.dataset
dataset_cls = module_dict.get(dataset_cfg.type, None)
if dataset_cls is None:
continue
if hasattr(dataset_cls, '_load_metainfo') and isinstance(
dataset_cls._load_metainfo, Callable):
meta = dataset_cls._load_metainfo(
dataset_cfg.get('metainfo', None))
if meta is not None:
return meta
if hasattr(dataset_cls, 'METAINFO'):
return dataset_cls.METAINFO
return None


@MMDET_TASK.register_module(Task.OBJECT_DETECTION.value)
class ObjectDetection(BaseTask):
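    """Object detection task class.

    Args:
        model_cfg (Config): Original PyTorch model config file.
        deploy_cfg (Config): Deployment config file.
        device (str): A string specifying the device to run on.
        experiment_name (str): Name of the current experiment.
            Defaults to 'ObjectDetection'.
    """
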
def __init__(self,
model_cfg: Config,
deploy_cfg: Config,
device: str,
experiment_name: str = 'ObjectDetection') -> None:
super().__init__(model_cfg, deploy_cfg, device, experiment_name)

    def build_backend_model(self,
                            model_files: Optional[str] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Optional[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
from .object_detection_model import build_object_detection_model
data_preprocessor = deepcopy(
self.model_cfg.model.get('data_preprocessor', {}))
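        # Fall back to mmdet's DetDataPreprocessor when the model config does
        # not specify a data_preprocessor type.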
data_preprocessor.setdefault('type', 'mmdet.DetDataPreprocessor')
model = build_object_detection_model(
model_files,
self.model_cfg,
self.deploy_cfg,
device=self.device,
data_preprocessor=data_preprocessor)
model = model.to(self.device)
return model.eval()

    def create_input(
        self,
        imgs: Union[str, np.ndarray],
        input_shape: Sequence[int] = None,
        data_preprocessor: Optional[BaseDataPreprocessor] = None
    ) -> Tuple[Dict, torch.Tensor]:
        """Create input for detector.

        Args:
            imgs (str | np.ndarray): Input image(s); accepted data types are
                str and np.ndarray.
            input_shape (list[int]): A list of two integers in (width, height)
                format specifying the input shape. Defaults to None.
            data_preprocessor (BaseDataPreprocessor): The data preprocessor
                used to preprocess and batch the inputs. Defaults to None.

        Returns:
            tuple: (data, inputs), where ``data`` holds the meta information
                of the input image(s) and ``inputs`` is the input tensor.
        """
from mmcv.transforms import Compose
if not isinstance(imgs, (list, tuple)):
imgs = [imgs]
dynamic_flag = is_dynamic_shape(self.deploy_cfg)
cfg = process_model_config(self.model_cfg, imgs, input_shape)
        # Drop pad_to_square for static shapes: the exported model expects a
        # fixed input size, so square padding must not alter it.
pipeline = cfg.test_pipeline
if not dynamic_flag:
transform = pipeline[1]
if 'transforms' in transform:
transform_list = transform['transforms']
for i, step in enumerate(transform_list):
if step['type'] == 'Pad' and 'pad_to_square' in step \
and step['pad_to_square']:
transform_list.pop(i)
break
test_pipeline = Compose(pipeline)
data = []
for img in imgs:
# prepare data
if isinstance(img, np.ndarray):
# TODO: remove img_id.
data_ = dict(img=img, img_id=0)
else:
# TODO: remove img_id.
data_ = dict(img_path=img, img_id=0)
# build the data pipeline
data_ = test_pipeline(data_)
data.append(data_)
data = pseudo_collate(data)
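        # When a data_preprocessor is supplied, run it so the returned inputs
        # are already preprocessed (e.g. normalized and batched); otherwise
        # fall back to the raw collated tensors.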
if data_preprocessor is not None:
data = data_preprocessor(data, False)
return data, data['inputs']
else:
return data, BaseTask.get_tensor_from_input(data)

    @staticmethod
    def get_partition_cfg(partition_type: str) -> Dict:
        """Get a certain partition config for mmdet.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.
        """
from .model_partition_cfg import MMDET_PARTITION_CFG
assert (partition_type in MMDET_PARTITION_CFG), \
f'Unknown partition_type {partition_type}'
return MMDET_PARTITION_CFG[partition_type]

    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get the preprocess information for SDK.

        Returns:
            dict: Composed of the preprocess information.
        """
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
pipeline = model_cfg.test_pipeline
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg',
'valid_ratio'
]
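        # Keep only deterministic transforms; random augmentations and
        # annotation loading are not part of SDK inference preprocessing.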
transforms = [
item for item in pipeline if 'Random' not in item['type']
and 'Annotation' not in item['type']
]
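        # Rewrite PackDetInputs into the SDK's Collect transform and rename
        # Resize's ``scale`` to ``size`` as expected by the SDK.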
for i, transform in enumerate(transforms):
if transform['type'] == 'PackDetInputs':
meta_keys += transform[
'meta_keys'] if 'meta_keys' in transform else []
transform['meta_keys'] = list(set(meta_keys))
transform['keys'] = ['img']
transforms[i]['type'] = 'Collect'
if transform['type'] == 'Resize':
transforms[i]['size'] = transforms[i].pop('scale')
data_preprocessor = model_cfg.model.data_preprocessor
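        # In mmdet 3.x normalization and padding are handled by the model's
        # data_preprocessor; expose them to the SDK as explicit transforms
        # inserted before the final Collect step.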
transforms.insert(-1, dict(type='DefaultFormatBundle'))
transforms.insert(
-2,
dict(
type='Pad',
size_divisor=data_preprocessor.get('pad_size_divisor', 1)))
transforms.insert(
-3,
dict(
type='Normalize',
to_rgb=data_preprocessor.get('bgr_to_rgb', False),
mean=data_preprocessor.get('mean', [0, 0, 0]),
std=data_preprocessor.get('std', [1, 1, 1])))
return transforms

    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get the postprocess information for SDK.

        Returns:
            dict: Composed of the postprocess information.
        """
params = self.model_cfg.model.test_cfg
type = 'ResizeBBox' # default for object detection
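        # Two-stage detectors keep their thresholds in nested rpn/rcnn
        # configs; flatten them so the SDK can read them directly. A
        # mask_thr_binary entry implies a mask head, i.e. instance
        # segmentation output.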
if 'rpn' in params:
params['min_bbox_size'] = params['rpn']['min_bbox_size']
if 'rcnn' in params:
params['score_thr'] = params['rcnn']['score_thr']
if 'mask_thr_binary' in params['rcnn']:
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
type = 'ResizeInstanceMask' # for instance-seg
return dict(type=type, params=params)

    def get_model_name(self, *args, **kwargs) -> str:
        """Get the model name.

        Returns:
            str: The name of the model.
        """
assert 'type' in self.model_cfg.model, 'model config contains no type'
name = self.model_cfg.model.type.lower()
return name

    def get_visualizer(self, name: str, save_dir: str):
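        """Build a visualizer and attach the dataset meta information (e.g.
        class names) from the model config when it is available.

        Args:
            name (str): Name of the visualizer instance.
            save_dir (str): Directory to save the visualization results.
        """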
visualizer = super().get_visualizer(name, save_dir)
metainfo = _get_dataset_metainfo(self.model_cfg)
if metainfo is not None:
visualizer.dataset_meta = metainfo
return visualizer