diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index 20956d057..721a2b27d 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -1,14 +1,15 @@ # MMDetection Deployment -- [Installation](#installation) - - [Install mmdet](#install-mmdet) - - [Install mmdeploy](#install-mmdeploy) -- [Convert model](#convert-model) -- [Model specification](#model-specification) -- [Model inference](#model-inference) - - [Backend model inference](#backend-model-inference) - - [SDK model inference](#sdk-model-inference) -- [Supported models](#supported-models) +- [MMDetection Deployment](#mmdetection-deployment) + - [Installation](#installation) + - [Install mmdet](#install-mmdet) + - [Install mmdeploy](#install-mmdeploy) + - [Convert model](#convert-model) + - [Model specification](#model-specification) + - [Model inference](#model-inference) + - [Backend model inference](#backend-model-inference) + - [SDK model inference](#sdk-model-inference) + - [Supported models](#supported-models) ______________________________________________________________________ @@ -26,7 +27,7 @@ There are several methods to install mmdeploy, among which you can choose an app **Method I:** Install precompiled package -> **TODO**. MMDeploy hasn't released based on dev-1.x branch. +> **TODO**. MMDeploy hasn't released based on 1.x branch. **Method II:** Build using scripts @@ -34,7 +35,7 @@ If your target platform is **Ubuntu 18.04 or later version**, we encourage you t [scripts](../01-how-to-build/build_from_script.md). For example, the following commands install mmdeploy as well as inference engine - `ONNX Runtime`. ```shell -git clone --recursive -b dev-1.x https://github.com/open-mmlab/mmdeploy.git +git clone --recursive -b 1.x https://github.com/open-mmlab/mmdeploy.git cd mmdeploy python3 tools/scripts/build_ubuntu_x64_ort.py $(nproc) export PYTHONPATH=$(pwd)/build/lib:$PYTHONPATH @@ -47,7 +48,7 @@ If neither **I** nor **II** meets your requirements, [building mmdeploy from sou ## Convert model -You can use [tools/deploy.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/tools/deploy.py) to convert mmdet models to the specified backend models. Its detailed usage can be learned from [here](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/convert_model.md#usage). +You can use [tools/deploy.py](https://github.com/open-mmlab/mmdeploy/blob/1.x/tools/deploy.py) to convert mmdet models to the specified backend models. Its detailed usage can be learned from [here](../02-how-to-run/convert_model.md). The command below shows an example about converting `Faster R-CNN` model to onnx model that can be inferred by ONNX Runtime. @@ -67,7 +68,7 @@ python tools/deploy.py \ --dump-info ``` -It is crucial to specify the correct deployment config during model conversion. We've already provided builtin deployment config [files](https://github.com/open-mmlab/mmdeploy/tree/dev-1.x/configs/mmdet) of all supported backends for mmdetection, under which the config file path follows the pattern: +It is crucial to specify the correct deployment config during model conversion. We've already provided builtin deployment config [files](https://github.com/open-mmlab/mmdeploy/tree/1.x/configs/mmdet) of all supported backends for mmdetection, under which the config file path follows the pattern: ``` {task}/{task}_{backend}-{precision}_{static | dynamic}_{shape}.py @@ -89,7 +90,7 @@ It is crucial to specify the correct deployment config during model conversion. - **{shape}:** input shape or shape range of a model -Therefore, in the above example, you can also convert `faster r-cnn` to other backend models by changing the deployment config file `detection_onnxruntime_dynamic.py` to [others](https://github.com/open-mmlab/mmdeploy/tree/dev-1.x/configs/mmdet/detection), e.g., converting to tensorrt-fp16 model by `detection_tensorrt-fp16_dynamic-320x320-1344x1344.py`. +Therefore, in the above example, you can also convert `faster r-cnn` to other backend models by changing the deployment config file `detection_onnxruntime_dynamic.py` to [others](https://github.com/open-mmlab/mmdeploy/tree/1.x/configs/mmdet/detection), e.g., converting to tensorrt-fp16 model by `detection_tensorrt-fp16_dynamic-320x320-1344x1344.py`. ```{tip} When converting mmdet models to tensorrt models, --device should be set to "cuda" @@ -184,7 +185,7 @@ for index, bbox, label_id in zip(indices, bboxes, labels): cv2.imwrite('output_detection.png', img) ``` -Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Interface), such as C, C++, C#, Java and so on. You can learn their usage from [demos](https://github.com/open-mmlab/mmdeploy/tree/dev-1.x/demo). +Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Interface), such as C, C++, C#, Java and so on. You can learn their usage from [demos](https://github.com/open-mmlab/mmdeploy/tree/1.x/demo). ## Supported models diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md index 74de4a3c7..39d8ad8bc 100644 --- a/docs/zh_cn/04-supported-codebases/mmdet.md +++ b/docs/zh_cn/04-supported-codebases/mmdet.md @@ -1,14 +1,15 @@ # MMDetection 模型部署 -- [安装](#安装) - - [安装 mmdet](#安装-mmdet) - - [安装 mmdeploy](#安装-mmdeploy) -- [模型转换](#模型转换) -- [模型规范](#模型规范) -- [模型推理](#模型推理) -- [后端模型推理](#后端模型推理) -- [SDK 模型推理](#sdk-模型推理) -- [模型支持列表](#模型支持列表) +- [MMDetection 模型部署](#mmdetection-模型部署) + - [安装](#安装) + - [安装 mmdet](#安装-mmdet) + - [安装 mmdeploy](#安装-mmdeploy) + - [模型转换](#模型转换) + - [模型规范](#模型规范) + - [模型推理](#模型推理) + - [后端模型推理](#后端模型推理) + - [SDK 模型推理](#sdk-模型推理) + - [模型支持列表](#模型支持列表) ______________________________________________________________________ @@ -26,7 +27,7 @@ mmdeploy 有以下几种安装方式: **方式一:** 安装预编译包 -> 待 mmdeploy 正式发布 dev-1.x,再补充 +> 待 mmdeploy 正式发布 1.x,再补充 **方式二:** 一键式脚本安装 @@ -34,7 +35,7 @@ mmdeploy 有以下几种安装方式: 比如,以下命令可以安装 mmdeploy 以及配套的推理引擎——`ONNX Runtime`. ```shell -git clone --recursive -b dev-1.x https://github.com/open-mmlab/mmdeploy.git +git clone --recursive -b 1.x https://github.com/open-mmlab/mmdeploy.git cd mmdeploy python3 tools/scripts/build_ubuntu_x64_ort.py $(nproc) export PYTHONPATH=$(pwd)/build/lib:$PYTHONPATH @@ -47,7 +48,7 @@ export LD_LIBRARY_PATH=$(pwd)/../mmdeploy-dep/onnxruntime-linux-x64-1.8.1/lib/:$ ## 模型转换 -你可以使用 [tools/deploy.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/tools/deploy.py) 把 mmdet 模型一键式转换为推理后端模型。 +你可以使用 [tools/deploy.py](https://github.com/open-mmlab/mmdeploy/blob/1.x/tools/deploy.py) 把 mmdet 模型一键式转换为推理后端模型。 该工具的详细使用说明请参考[这里](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/convert_model.md#usage). 以下,我们将演示如何把 `Faster R-CNN` 转换为 onnx 模型。 @@ -68,7 +69,7 @@ python tools/deploy.py \ --dump-info ``` -转换的关键之一是使用正确的配置文件。项目中已内置了各后端部署[配置文件](https://github.com/open-mmlab/mmdeploy/tree/dev-1.x/configs/mmdet)。 +转换的关键之一是使用正确的配置文件。项目中已内置了各后端部署[配置文件](https://github.com/open-mmlab/mmdeploy/tree/1.x/configs/mmdet)。 文件的命名模式是: ``` @@ -187,7 +188,7 @@ cv2.imwrite('output_detection.png', img) ``` 除了python API,mmdeploy SDK 还提供了诸如 C、C++、C#、Java等多语言接口。 -你可以参考[样例](https://github.com/open-mmlab/mmdeploy/tree/dev-1.x/demo)学习其他语言接口的使用方法。 +你可以参考[样例](https://github.com/open-mmlab/mmdeploy/tree/1.x/demo)学习其他语言接口的使用方法。 ## 模型支持列表 diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 9ccb6fc8f..1e88372e4 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -25,12 +25,14 @@ class MMDetection(MMCodebase): @classmethod def register_deploy_modules(cls): + """register all rewriters for mmdet.""" import mmdeploy.codebase.mmdet.models # noqa: F401 import mmdeploy.codebase.mmdet.ops import mmdeploy.codebase.mmdet.structures # noqa: F401 @classmethod def register_all_modules(cls): + """register all related modules and rewriters for mmdet.""" from mmdet.utils.setup_env import register_all_modules cls.register_deploy_modules() @@ -114,6 +116,15 @@ def _get_dataset_metainfo(model_cfg: Config): @MMDET_TASK.register_module(Task.OBJECT_DETECTION.value) class ObjectDetection(BaseTask): + """Object Detection task. + + Args: + model_cfg (Config): The config of the model in mmdet. + deploy_cfg (Config): The config of deployment. + device (str): Device name. + experiment_name (str, optional): The experiment name used to create + runner. Defaults to 'ObjectDetection'. + """ def __init__(self, model_cfg: Config, @@ -161,6 +172,8 @@ class ObjectDetection(BaseTask): `str`, `np.ndarray`. input_shape (list[int]): A list of two integer in (width, height) format specifying input shape. Defaults to `None`. + data_preprocessor (BaseDataPreprocessor): The data preprocessor + of the model. Default to `None`. Returns: tuple: (data, img), meta information for the input image and input. diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 50379fd2f..2de340934 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -32,9 +32,9 @@ class End2EndModel(BaseBackendModel): backend_files (Sequence[str]): Paths to all required backend files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). device (str): A string specifying device type. - class_names (Sequence[str]): A list of string specifying class names. deploy_cfg (str|Config): Deployment config file or loaded Config object. + data_preprocessor (dict|nn.Module): The data preprocessor. """ def __init__(self, @@ -172,6 +172,17 @@ class End2EndModel(BaseBackendModel): data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'predict', **kwargs) -> Any: + """The model forward. + + Args: + inputs (torch.Tensor): The input tensors + data_samples (List[BaseDataElement], optional): The data samples. + Defaults to None. + mode (str, optional): forward mode, only support `predict`. + + Returns: + Any: Model output. + """ assert mode == 'predict', 'Deploy model only allow mode=="predict".' inputs = inputs.contiguous() outputs = self.predict(inputs) @@ -247,7 +258,7 @@ class End2EndModel(BaseBackendModel): return results def predict(self, imgs: Tensor) -> Tuple[np.ndarray, np.ndarray]: - """The interface for forward test. + """The interface for predict. Args: imgs (Tensor): Input image(s) in [N x C x H x W] format. @@ -270,19 +281,23 @@ class PartitionSingleStageModel(End2EndModel): backend_files (Sequence[str]): Paths to all required backend files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). device (str): A string specifying device type. - class_names (Sequence[str]): A list of string specifying class names. model_cfg (str|Config): Input model config file or Config object. deploy_cfg (str|Config): Deployment config file or loaded Config object. + data_preprocessor (dict|nn.Module): The data preprocessor. """ - def __init__(self, backend: Backend, backend_files: Sequence[str], - device: str, class_names: Sequence[str], - model_cfg: Union[str, Config], deploy_cfg: Union[str, Config], + def __init__(self, + backend: Backend, + backend_files: Sequence[str], + device: str, + model_cfg: Union[str, Config], + deploy_cfg: Union[str, Config], + data_preprocessor: Optional[Union[dict, nn.Module]] = None, **kwargs): - super().__init__(backend, backend_files, device, class_names, - deploy_cfg, **kwargs) + super().__init__(backend, backend_files, device, deploy_cfg, + data_preprocessor, **kwargs) # load cfg if necessary model_cfg = load_config(model_cfg)[0] self.model_cfg = model_cfg @@ -353,16 +368,19 @@ class PartitionTwoStageModel(End2EndModel): backend_files (Sequence[str]): Paths to all required backend files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). device (str): A string specifying device type. - class_names (Sequence[str]): A list of string specifying class names. model_cfg (str|Config): Input model config file or Config object. deploy_cfg (str|Config): Deployment config file or loaded Config object. """ - def __init__(self, backend: Backend, backend_files: Sequence[str], - device: str, class_names: Sequence[str], - model_cfg: Union[str, Config], deploy_cfg: Union[str, Config], + def __init__(self, + backend: Backend, + backend_files: Sequence[str], + device: str, + model_cfg: Union[str, Config], + deploy_cfg: Union[str, Config], + data_preprocessor: Optional[Union[dict, nn.Module]] = None, **kwargs): # load cfg if necessary @@ -370,8 +388,8 @@ class PartitionTwoStageModel(End2EndModel): self.model_cfg = model_cfg - super().__init__(backend, backend_files, device, class_names, - deploy_cfg, **kwargs) + super().__init__(backend, backend_files, device, deploy_cfg, + data_preprocessor, **kwargs) from mmdet.models.builder import build_head, build_roi_extractor from ..models.roi_heads.bbox_head import bbox_head__get_bboxes @@ -535,7 +553,6 @@ class NCNNEnd2EndModel(End2EndModel): backend_files (Sequence[str]): Paths to all required backend files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). device (str): A string specifying device type. - class_names (Sequence[str]): A list of string specifying class names. model_cfg (str|Config): Input model config file or Config object. deploy_cfg (str|Config): Deployment config file or loaded Config @@ -596,7 +613,8 @@ class SDKEnd2EndModel(End2EndModel): Args: img (Sequence[Tensor]): A list contains input image(s) in [N x C x H x W] format. - img_metas (Sequence[dict]): A list of meta info for image(s). + data_samples (List[BaseDataElement]): A list of meta info + for image(s). *args: Other arguments. **kwargs: Other key-pair arguments. @@ -639,6 +657,16 @@ class RKNNModel(End2EndModel): """RKNNModel. RKNN inference class, converts RKNN output to mmdet format. + + Args: + backend (Backend): The backend enum, specifying backend type. + backend_files (Sequence[str]): Paths to all required backend files + (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). + device (str): A string specifying device type. + model_cfg (str|Config): Input model config file or Config + object. + deploy_cfg (str|Config): Deployment config file or loaded Config + object. """ def __init__(self, backend: Backend, backend_files: Sequence[str], @@ -653,7 +681,19 @@ class RKNNModel(End2EndModel): model_cfg = load_config(model_cfg)[0] self.model_cfg = model_cfg - def _get_bboxes(self, outputs, metainfos): + def _get_bboxes(self, outputs: List[Tensor], metainfos: Any): + """get bboxes from output by meta infos. + + Args: + outputs (List[Tensor]): The backend wrapper outputs. + metainfos (Any): The meta infos of inputs. + + Raises: + NotImplementedError: Head type not supported. + + Returns: + Any: model outputs. + """ from mmdet.models import build_head head_cfg = self.model_cfg._cfg_dict.model.bbox_head head = build_head(head_cfg) diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index f1a59849f..656200234 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -41,6 +41,7 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, x2 (Tensor): The x2 for bounding boxes. y2 (Tensor): The y2 for bounding boxes. max_shape (Tensor | Sequence[int]): The (H,W) of original image. + Returns: tuple(Tensor): The clipped x1, y1, x2, y2. """ @@ -87,6 +88,7 @@ def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, x2 (Tensor): The x2 for bounding boxes. y2 (Tensor): The y2 for bounding boxes. max_shape (Tensor | Sequence[int]): The (H,W) of original image. + Returns: tuple(Tensor): The clipped x1, y1, x2, y2. """ @@ -125,7 +127,7 @@ def pad_with_value(x: Tensor, def pad_with_value_if_necessary(x: Tensor, pad_dim: int, pad_size: int, - pad_value: Optional[Any] = None): + pad_value: Optional[Any] = None) -> Tensor: """Pad a tensor with a value along some dim if necessary. Args: @@ -144,7 +146,7 @@ def pad_with_value_if_necessary(x: Tensor, def __pad_with_value_if_necessary(x: Tensor, pad_dim: int, pad_size: int, - pad_value: Optional[Any] = None): + pad_value: Optional[Any] = None) -> Tensor: """Pad a tensor with a value along some dim, do nothing on default. Args: @@ -162,11 +164,12 @@ def __pad_with_value_if_necessary(x: Tensor, @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary', backend=Backend.TENSORRT.value) -def __pad_with_value_if_necessary__tensorrt(ctx, - x: Tensor, - pad_dim: int, - pad_size: int, - pad_value: Optional[Any] = None): +def __pad_with_value_if_necessary__tensorrt( + ctx, + x: Tensor, + pad_dim: int, + pad_size: int, + pad_value: Optional[Any] = None) -> Tensor: """Pad a tensor with a value along some dim. Args: diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index 57a322fce..3ef050d5c 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -41,7 +41,7 @@ def detrhead__predict_by_feat__default(ctx, all_bbox_preds_list: List[Tensor], batch_img_metas: List[dict], rescale: bool = True): - """Rewrite `get_bboxes` of `FoveaHead` for default backend.""" + """Rewrite `predict_by_feat` of `FoveaHead` for default backend.""" from mmdet.structures.bbox import bbox_cxcywh_to_xyxy cls_scores = all_cls_scores_list[-1][-1] bbox_preds = all_bbox_preds_list[-1][-1] diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py index 43258b7db..4583d0101 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py @@ -110,8 +110,18 @@ def gfl_head__predict_by_feat(ctx, 1).reshape(batch_size, -1).sigmoid() score_factors = score_factors.unsqueeze(2) - bbox_pred = batched_integral(self.integral, - bbox_pred.permute(0, 2, 3, 1)) * stride[0] + + def _batched_integral(intergral, x): + batch_size = x.size(0) + x = F.softmax( + x.reshape(batch_size, -1, intergral.reg_max + 1), dim=2) + x = F.linear(x, + intergral.project.type_as(x).unsqueeze(0)).reshape( + batch_size, -1, 4) + return x + + bbox_pred = _batched_integral( + self.integral, bbox_pred.permute(0, 2, 3, 1)) * stride[0] if not is_dynamic_flag: priors = priors.data if pre_topk > 0: @@ -181,12 +191,3 @@ def gfl_head__predict_by_feat(ctx, score_threshold=score_threshold, pre_top_k=pre_top_k, keep_top_k=keep_top_k) - - -def batched_integral(intergral, x): - batch_size = x.size(0) - x = F.softmax(x.reshape(batch_size, -1, intergral.reg_max + 1), dim=2) - x = F.linear(x, - intergral.project.type_as(x).unsqueeze(0)).reshape( - batch_size, -1, 4) - return x diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py index dfc5e0ee3..e24998a92 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py @@ -64,7 +64,7 @@ def reppoints_head__predict_by_feat( cfg: Optional[ConfigDict] = None, rescale: bool = False, with_nms: bool = True) -> InstanceData: - """Rewrite `get_bboxes` of `RepPointsHead` for default backend. + """Rewrite `predict_by_feat` of `RepPointsHead` for default backend. Rewrite this function to deploy model, transform network output for a batch into bbox predictions. diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py index 53a3c29e5..05f248486 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py @@ -160,6 +160,7 @@ def rpn_head__predict_by_feat(ctx, keep_top_k=keep_top_k) +# TODO: Fix for 1.x @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value) def rpn_head__get_bboxes__ncnn(ctx, diff --git a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py index 8dfe6c3dd..adfb6831f 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py @@ -43,8 +43,7 @@ def single_stage_detector__forward(ctx, data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - rescale (bool): Whether to rescale the results. - Defaults to True. + mode (str): export mode, not used. Returns: tuple[Tensor]: Detection results of the diff --git a/mmdeploy/codebase/mmdet/models/detectors/two_stage.py b/mmdeploy/codebase/mmdet/models/detectors/two_stage.py index 1046319e3..9b20fed83 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/two_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/two_stage.py @@ -48,8 +48,7 @@ def two_stage_detector__forward(ctx, data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - rescale (bool): Whether to rescale the results. - Defaults to True. + mode (str): export mode, not used. Returns: tuple[Tensor]: Detection results of the diff --git a/mmdeploy/codebase/mmdet/models/necks.py b/mmdeploy/codebase/mmdet/models/necks.py index 2931de9b0..adc40fa12 100644 --- a/mmdeploy/codebase/mmdet/models/necks.py +++ b/mmdeploy/codebase/mmdet/models/necks.py @@ -2,12 +2,16 @@ import torch from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.utils import Backend +from mmdeploy.utils import Backend, get_root_logger @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.necks.ssd_neck.L2Norm.forward') def l2norm__forward__default(ctx, self, x): + """Default rewriter for l2norm. + + Implement with functinoal.normalize . + """ return torch.nn.functional.normalize( x, dim=1) * self.weight[None, :, None, None] @@ -20,10 +24,16 @@ def l2norm__forward__tensorrt(ctx, self, x): TensorRT7 does not support dynamic clamp, which is used in normalize. """ - import tensorrt as trt - from packaging import version - trt_version = version.parse(trt.__version__) - if trt_version.major >= 8: + logger = get_root_logger() + trt_version_major = 8 + try: + import tensorrt as trt + from packaging import version + trt_version = version.parse(trt.__version__) + trt_version_major = trt_version.major + except Exception: + logger.warning('Can not get TensorRT version.') + if trt_version_major >= 8: return l2norm__forward__default(ctx, self, x) else: return ctx.origin_func(self, x) diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py index 01bce47d7..2f3da854c 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py @@ -16,9 +16,6 @@ class MultiLevelRoiAlign(Function): backend. """ - def __init__(self) -> None: - super().__init__() - @staticmethod def symbolic(g, *args): """Symbolic function for creating onnx op.""" @@ -265,10 +262,12 @@ class SingleRoIExtractorOpenVINO(Function): @staticmethod def forward(g, output_size, featmap_strides, sample_num, rois, *feats): + """Run forward.""" return SingleRoIExtractorOpenVINO.origin_output @staticmethod def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats): + """Symbolic function for creating onnx op.""" from torch.onnx.symbolic_helper import _slice_helper rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5]) domain = 'org.openvinotoolkit' diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py b/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py deleted file mode 100644 index e21b2acd3..000000000 --- a/mmdeploy/codebase/mmdet/models/roi_heads/test_mixins.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmdeploy.core import FUNCTION_REWRITER - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.roi_heads.test_mixins.BBoxTestMixin.simple_test_bboxes') -def bbox_test_mixin__simple_test_bboxes(ctx, - self, - x, - img_metas, - proposals, - rcnn_test_cfg, - rescale=False): - """Rewrite `simple_test_bboxes` of `BBoxTestMixin` for default backend. - - 1. This function eliminates the batch dimension to get forward bbox - results, and recover batch dimension to calculate final result - for deployment. - 2. This function returns detection result as Tensor instead of numpy - array. - - Args: - ctx (ContextCaller): The context with additional information. - self: The instance of the original class. - x (tuple[Tensor]): Features from upstream network. Each - has shape (batch_size, c, h, w). - img_metas (list[dict]): Meta information of images. - proposals (list(Tensor)): Proposals from rpn head. - Each has shape (num_proposals, 5), last dimension - 5 represent (x1, y1, x2, y2, score). - rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. - rescale (bool): If True, return boxes in original image space. - Default: False. - - Returns: - tuple[Tensor, Tensor]: (det_bboxes, det_labels), `det_bboxes` of - shape [N, num_det, 5] and `det_labels` of shape [N, num_det]. - """ - rois = proposals - batch_index = torch.arange( - rois.shape[0], device=rois.device).float().view(-1, 1, 1).expand( - rois.size(0), rois.size(1), 1) - rois = torch.cat([batch_index, rois[..., :4]], dim=-1) - batch_size = rois.shape[0] - - # Eliminate the batch dimension - rois = rois.view(-1, 5) - bbox_results = self._bbox_forward(x, rois) - cls_score = bbox_results['cls_score'] - bbox_pred = bbox_results['bbox_pred'] - - # Recover the batch dimension - rois = rois.reshape(batch_size, -1, rois.size(-1)) - cls_score = cls_score.reshape(batch_size, -1, cls_score.size(-1)) - - bbox_pred = bbox_pred.reshape(batch_size, -1, bbox_pred.size(-1)) - det_bboxes, det_labels = self.bbox_head.get_bboxes( - rois, - cls_score, - bbox_pred, - img_metas[0]['img_shape'], - None, - rescale=rescale, - cfg=rcnn_test_cfg) - return det_bboxes, det_labels - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet.models.roi_heads.test_mixins.MaskTestMixin.simple_test_mask') -def mask_test_mixin__simple_test_mask(ctx, self, x, img_metas, det_bboxes, - det_labels, **kwargs): - """Rewrite `simple_test_mask` of `BBoxTestMixin` for default backend. - - This function returns detection result as Tensor instead of numpy - array. - - Args: - ctx (ContextCaller): The context with additional information. - self: The instance of the original class. - x (tuple[Tensor]): Features from upstream network. Each - has shape (batch_size, c, h, w). - img_metas (list[dict]): Meta information of images. - det_bboxes (tuple[Tensor]): Detection bounding-boxes from features. - Each has shape of (batch_size, num_det, 5). - det_labels (tuple[Tensor]): Detection labels from features. Each - has shape of (batch_size, num_det). - - Returns: - tuple[Tensor]: (segm_results), `segm_results` of shape - [N, num_det, roi_H, roi_W]. - """ - batch_size = det_bboxes.size(0) - det_bboxes = det_bboxes[..., :4] - batch_index = torch.arange( - det_bboxes.size(0), - device=det_bboxes.device).float().view(-1, 1, 1).expand( - det_bboxes.size(0), det_bboxes.size(1), 1) - mask_rois = torch.cat([batch_index, det_bboxes], dim=-1) - mask_rois = mask_rois.view(-1, 5) - mask_results = self._mask_forward(x, mask_rois) - mask_pred = mask_results['mask_pred'] - max_shape = img_metas[0]['img_shape'] - num_det = det_bboxes.shape[1] - det_bboxes = det_bboxes.reshape(-1, 4) - det_labels = det_labels.reshape(-1) - segm_results = self.mask_head.get_seg_masks(mask_pred, det_bboxes, - det_labels, self.test_cfg, - max_shape) - segm_results = segm_results.reshape(batch_size, num_det, - segm_results.shape[-2], - segm_results.shape[-1]) - return segm_results diff --git a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py index 4287de6a8..7550dd03b 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py @@ -12,6 +12,7 @@ class GridPriorsTRTOp(torch.autograd.Function): @staticmethod def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int, stride_w: int): + """Generate grid priors by base anchors.""" device = base_anchors.device dtype = base_anchors.dtype shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w @@ -38,6 +39,7 @@ class GridPriorsTRTOp(torch.autograd.Function): @symbolic_helper.parse_args('v', 'v', 'v', 'i', 'i') def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int, stride_w: int): + """Map ops to onnx symbolics.""" # zero_h and zero_w is used to provide shape to GridPriorsTRT feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0]) feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0]) diff --git a/mmdeploy/codebase/mmdet/models/transformer.py b/mmdeploy/codebase/mmdet/models/transformer.py index 6bbb938de..7ff62c675 100644 --- a/mmdeploy/codebase/mmdet/models/transformer.py +++ b/mmdeploy/codebase/mmdet/models/transformer.py @@ -8,14 +8,16 @@ from mmdeploy.core import FUNCTION_REWRITER func_name='mmdet.models.utils.transformer.PatchMerging.forward', backend='tensorrt') def patch_merging__forward__tensorrt(ctx, self, x, input_size): - """Rewrite forward function of PatchMerging class for TensorRT. - In original implementation, mmdet applies nn.unfold to accelerate the - inferece. However, the onnx graph of it can not be parsed correctly by - TensorRT. In mmdeploy, it is replaced. + """Rewrite forward function of PatchMerging class for TensorRT. In original + implementation, mmdet applies nn.unfold to accelerate the inference. + However, the onnx graph of it can not be parsed correctly by TensorRT. In + mmdeploy, it is replaced. + Args: x (Tensor): Has shape (B, H*W, C_in). input_size (tuple[int]): The spatial shape of x, arrange as (H, W). Default: None. + Returns: tuple: Contains merged results and its spatial shape. - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)