diff --git a/demo/demo_rewrite.py b/demo/demo_rewrite.py index a624c26eb..a11bc9e0e 100644 --- a/demo/demo_rewrite.py +++ b/demo/demo_rewrite.py @@ -13,7 +13,7 @@ from mmdeploy.utils import get_root_logger @FUNCTION_REWRITER.register_rewriter( func_name='torchvision.models.ResNet._forward_impl') -def forward_of_resnet(ctx, self, x): +def forward_of_resnet(self, x): """Rewrite the forward implementation of resnet. Early return the feature map after two down-sampling steps. diff --git a/docs/en/07-developer-guide/partition_model.md b/docs/en/07-developer-guide/partition_model.md index 96aa8b73e..f1f1420b0 100644 --- a/docs/en/07-developer-guide/partition_model.md +++ b/docs/en/07-developer-guide/partition_model.md @@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @mark( 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) -def __forward_impl(ctx, self, img, img_metas=None, **kwargs): +def __forward_impl(self, img, img_metas=None, **kwargs): ... @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.base.BaseDetector.forward') -def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): +def base_detector__forward(self, img, img_metas=None, **kwargs): ... # call the mark function return __forward_impl(...) 
@@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes') -def yolov3_head__get_bboxes(ctx, - self, +def yolov3_head__get_bboxes(self, pred_maps, img_metas, cfg=None, diff --git a/docs/en/07-developer-guide/support_new_model.md b/docs/en/07-developer-guide/support_new_model.md index ae456a45b..1fb4c012c 100644 --- a/docs/en/07-developer-guide/support_new_model.md +++ b/docs/en/07-developer-guide/support_new_model.md @@ -11,7 +11,8 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.repeat', backend='tensorrt') -def repeat_static(ctx, input, *size): +def repeat_static(input, *size): + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func if input.dim() == 1 and len(size) == 1: return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0) @@ -72,7 +73,7 @@ The mappings between PyTorch and ONNX are defined in PyTorch with symbolic funct ```python @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True) -def squeeze_default(ctx, g, self, dim=None): +def squeeze_default(g, self, dim=None): if dim is None: dims = [] for i, size in enumerate(self.type().sizes()): diff --git a/docs/en/07-developer-guide/test_rewritten_models.md b/docs/en/07-developer-guide/test_rewritten_models.md index 311e2adbd..e81e79fe0 100644 --- a/docs/en/07-developer-guide/test_rewritten_models.md +++ b/docs/en/07-developer-guide/test_rewritten_models.md @@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta): # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.classifiers.BaseClassifier.forward', backend='default') -def forward_of_base_classifier(ctx, self, img, *args, **kwargs): +def forward_of_base_classifier(self, img, *args, **kwargs): """Rewrite `forward` for default backend.""" return self.simple_test(img, {}) ``` @@ -63,7 +63,8 @@ In the first example, the 
output is generated in Python. Sometimes we may make b # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.BaseSegmentor.forward') -def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs): +def base_segmentor__forward(self, img, img_metas=None, **kwargs): + ctx = FUNCTION_REWRITER.get_context() if img_metas is None: img_metas = {} assert isinstance(img_metas, dict) diff --git a/docs/zh_cn/07-developer-guide/partition_model.md b/docs/zh_cn/07-developer-guide/partition_model.md index 2356554d4..bfcaa1058 100644 --- a/docs/zh_cn/07-developer-guide/partition_model.md +++ b/docs/zh_cn/07-developer-guide/partition_model.md @@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @mark( 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) -def __forward_impl(ctx, self, img, img_metas=None, **kwargs): +def __forward_impl(self, img, img_metas=None, **kwargs): ... @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.base.BaseDetector.forward') -def base_detector__forward(ctx, self, img, img_metas=None, **kwargs): +def base_detector__forward(self, img, img_metas=None, **kwargs): ... # call the mark function return __forward_impl(...) 
@@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes') -def yolov3_head__get_bboxes(ctx, - self, +def yolov3_head__get_bboxes(self, pred_maps, img_metas, cfg=None, diff --git a/docs/zh_cn/07-developer-guide/support_new_model.md b/docs/zh_cn/07-developer-guide/support_new_model.md index 7c9cd72ad..727a9a235 100644 --- a/docs/zh_cn/07-developer-guide/support_new_model.md +++ b/docs/zh_cn/07-developer-guide/support_new_model.md @@ -10,7 +10,8 @@ PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但 from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.repeat', backend='tensorrt') -def repeat_static(ctx, input, *size): +def repeat_static(input, *size): + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func if input.dim() == 1 and len(size) == 1: return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0) @@ -67,7 +68,7 @@ PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义 ```python @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True) -def squeeze_default(ctx, g, self, dim=None): +def squeeze_default(g, self, dim=None): if dim is None: dims = [] for i, size in enumerate(self.type().sizes()): diff --git a/docs/zh_cn/07-developer-guide/test_rewritten_models.md b/docs/zh_cn/07-developer-guide/test_rewritten_models.md index 0ae0111de..16f3a96e0 100644 --- a/docs/zh_cn/07-developer-guide/test_rewritten_models.md +++ b/docs/zh_cn/07-developer-guide/test_rewritten_models.md @@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta): # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.classifiers.BaseClassifier.forward', backend='default') -def forward_of_base_classifier(ctx, self, img, *args, **kwargs): +def forward_of_base_classifier(self, img, *args, **kwargs): """Rewrite `forward` for default backend.""" return self.simple_test(img, {}) ``` @@ -63,7 +63,8 @@ def 
test_baseclassfier_forward(): # Custom rewritten function @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.BaseSegmentor.forward') -def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs): +def base_segmentor__forward(self, img, img_metas=None, **kwargs): + ctx = FUNCTION_REWRITER.get_context() if img_metas is None: img_metas = {} assert isinstance(img_metas, dict) diff --git a/mmdeploy/apis/onnx/export.py b/mmdeploy/apis/onnx/export.py index 6ca127af9..92a9002d8 100644 --- a/mmdeploy/apis/onnx/export.py +++ b/mmdeploy/apis/onnx/export.py @@ -116,6 +116,15 @@ def export(model: torch.nn.Module, input_metas, dict ), f'Expect input_metas type is dict, get {type(input_metas)}.' model_forward = patched_model.forward + + def wrap_forward(forward): + + def wrapper(*arg, **kwargs): + return forward(*arg, **kwargs) + + return wrapper + + patched_model.forward = wrap_forward(patched_model.forward) patched_model.forward = partial(patched_model.forward, **input_metas) diff --git a/mmdeploy/apis/onnx/optimizer.py b/mmdeploy/apis/onnx/optimizer.py index 1456292bb..bfc2cc0ab 100644 --- a/mmdeploy/apis/onnx/optimizer.py +++ b/mmdeploy/apis/onnx/optimizer.py @@ -5,8 +5,9 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph') -def model_to_graph__custom_optimizer(ctx, *args, **kwargs): +def model_to_graph__custom_optimizer(*args, **kwargs): """Rewriter of _model_to_graph, add custom passes.""" + ctx = FUNCTION_REWRITER.get_context() graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs) custom_passes = getattr(ctx, 'onnx_custom_passes', None) @@ -23,8 +24,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs): @FUNCTION_REWRITER.register_rewriter( 'torch._C._jit_pass_onnx_deduplicate_initializers', backend='tensorrt') -def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict, - arg2): +def jit_pass_onnx_deduplicate_initializers__disable(graph, 
param_dict, arg2): """This pass will disable TensorRT topk export. disable for TensorRT. @@ -34,6 +34,6 @@ def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict, @FUNCTION_REWRITER.register_rewriter( 'torch._C._jit_pass_onnx_autograd_function_process') -def jit_pass_onnx_autograd_function_process__disable(ctx, graph): +def jit_pass_onnx_autograd_function_process__disable(graph): """Disable process autograph function.""" return diff --git a/mmdeploy/codebase/mmaction/models/recognizers/base.py b/mmdeploy/codebase/mmaction/models/recognizers/base.py index 5504f2166..7e667e128 100644 --- a/mmdeploy/codebase/mmaction/models/recognizers/base.py +++ b/mmdeploy/codebase/mmaction/models/recognizers/base.py @@ -8,8 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmaction.models.recognizers.BaseRecognizer.forward') -def base_recognizer__forward(ctx, - self, +def base_recognizer__forward(self, inputs: Tensor, data_samples: OptSampleList = None, mode: str = 'tensor', diff --git a/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py b/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py index fe3a73d0b..d47c0c6cf 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py +++ b/mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py @@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward') -def shufflenetv2_backbone__forward__default(ctx, self, x): +def shufflenetv2_backbone__forward__default(self, x): """Rewrite `forward` of InvertedResidual used in shufflenet_v2. 
The chunk in original InvertedResidual.forward will convert to dynamic diff --git a/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py b/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py index a31853912..2acf13bb8 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py +++ b/mmdeploy/codebase/mmcls/models/backbones/vision_transformer.py @@ -9,7 +9,7 @@ from mmdeploy.utils import Backend func_name= # noqa: E251 'mmcls.models.backbones.vision_transformer.VisionTransformer.forward', backend=Backend.NCNN.value) -def visiontransformer__forward__ncnn(ctx, self, x): +def visiontransformer__forward__ncnn(self, x): """Rewrite `forward` of VisionTransformer for ncnn backend. The chunk in original VisionTransformer.forward will convert diff --git a/mmdeploy/codebase/mmcls/models/classifiers/base.py b/mmdeploy/codebase/mmcls/models/classifiers/base.py index aebcb8f3c..ec211bcf4 100644 --- a/mmdeploy/codebase/mmcls/models/classifiers/base.py +++ b/mmdeploy/codebase/mmcls/models/classifiers/base.py @@ -12,7 +12,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.classifiers.BaseClassifier.forward', backend='default') def base_classifier__forward( - ctx, self, batch_inputs: Tensor, data_samples: Optional[List[BaseDataElement]] = None, diff --git a/mmdeploy/codebase/mmcls/models/necks/gap.py b/mmdeploy/codebase/mmcls/models/necks/gap.py index d89939def..f17d0ebac 100644 --- a/mmdeploy/codebase/mmcls/models/necks/gap.py +++ b/mmdeploy/codebase/mmcls/models/necks/gap.py @@ -9,7 +9,7 @@ from mmdeploy.utils import Backend @FUNCTION_REWRITER.register_rewriter( 'mmcls.models.necks.GlobalAveragePooling.forward', backend=Backend.DEFAULT.value) -def gap__forward(ctx, self, inputs): +def gap__forward(self, inputs): """Rewrite `forward` of GlobalAveragePooling for default backend. Replace `view` with `flatten` to export simple onnx graph. 
diff --git a/mmdeploy/codebase/mmcls/models/utils/attention.py b/mmdeploy/codebase/mmcls/models/utils/attention.py index edbbc1169..96adea1ad 100644 --- a/mmdeploy/codebase/mmcls/models/utils/attention.py +++ b/mmdeploy/codebase/mmcls/models/utils/attention.py @@ -11,7 +11,7 @@ from mmdeploy.utils import Backend, get_dynamic_axes @FUNCTION_REWRITER.register_rewriter( func_name='mmcls.models.utils.attention.MultiheadAttention.forward', backend=Backend.NCNN.value) -def multiheadattention__forward__ncnn(ctx, self, qkv_input): +def multiheadattention__forward__ncnn(self, qkv_input): """Rewrite `forward` of MultiheadAttention used in vision_transformer for ncnn backend. @@ -53,12 +53,13 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input): func_name= # noqa: E251 'mmcls.models.utils.ShiftWindowMSA.forward', extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def shift_window_msa__forward__default(ctx, self, query, hw_shape): +def shift_window_msa__forward__default(self, query, hw_shape): """Rewrite forward function of ShiftWindowMSA class for TensorRT. 1. replace dynamic padding with static padding and dynamic slice. 2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability. 
""" + ctx = FUNCTION_REWRITER.get_context() if get_dynamic_axes(ctx.cfg) is None: # avoid the weird bug of torch to onnx return ctx.origin_func(self, query, hw_shape) @@ -142,8 +143,7 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape): func_name= # noqa: E251 'mmcls.models.utils.ShiftWindowMSA.get_attn_mask', extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0')) -def shift_window_msa__get_attn_mask__default(ctx, - self, +def shift_window_msa__get_attn_mask__default(self, hw_shape, window_size, shift_size, diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 656200234..a7dc0b6fb 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -76,7 +76,7 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes', backend='tensorrt', extra_checkers=LibVersionChecker('tensorrt', min_version='8')) -def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, +def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, max_shape: Union[Tensor, Sequence[int]]): """Clip bboxes for onnx. From TensorRT 8 we can do the operators on the tensors directly. 
@@ -165,7 +165,6 @@ def __pad_with_value_if_necessary(x: Tensor, 'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary', backend=Backend.TENSORRT.value) def __pad_with_value_if_necessary__tensorrt( - ctx, x: Tensor, pad_dim: int, pad_size: int, @@ -223,12 +222,12 @@ class TRTGatherTopk(torch.autograd.Function): @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk', backend=Backend.TENSORRT.value) -def __gather_topk__trt(ctx, - *inputs: Sequence[torch.Tensor], +def __gather_topk__trt(*inputs: Sequence[torch.Tensor], inds: torch.Tensor, batch_size: int, is_batched: bool = True) -> Tuple[torch.Tensor]: """TensorRT gather_topk.""" + ctx = FUNCTION_REWRITER.get_context() _ = ctx if is_batched: index_shape = inds.shape @@ -253,8 +252,7 @@ def __gather_topk__trt(ctx, @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk', backend=Backend.COREML.value) -def __gather_topk__nonbatch(ctx, - *inputs: Sequence[torch.Tensor], +def __gather_topk__nonbatch(*inputs: Sequence[torch.Tensor], inds: torch.Tensor, batch_size: int, is_batched: bool = True) -> Tuple[torch.Tensor]: diff --git a/mmdeploy/codebase/mmdet/models/backbones.py b/mmdeploy/codebase/mmdet/models/backbones.py index 122a362b9..6f6a72d5c 100644 --- a/mmdeploy/codebase/mmdet/models/backbones.py +++ b/mmdeploy/codebase/mmdet/models/backbones.py @@ -7,7 +7,7 @@ from mmdeploy.utils import get_common_config @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.csp_darknet.Focus.forward') -def focus__forward__default(ctx, self, x): +def focus__forward__default(self, x): """Rewrite forward function of Focus class. Replace slice with transpose. 
@@ -27,7 +27,7 @@ def focus__forward__default(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.csp_darknet.Focus.forward', backend='ncnn') -def focus__forward__ncnn(ctx, self, x): +def focus__forward__ncnn(self, x): """Rewrite forward function of Focus class for ncnn. Focus width and height information into channel space. ncnn does not @@ -69,7 +69,7 @@ def focus__forward__ncnn(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.WindowMSA.forward', backend='tensorrt') -def windowmsa__forward__tensorrt(ctx, self, x, mask=None): +def windowmsa__forward__tensorrt(self, x, mask=None): """Rewrite forward function of WindowMSA class for TensorRT. 1. replace Gather operation of qkv with split. @@ -80,6 +80,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None): mask (tensor | None, Optional): mask with shape of (num_windows, Wh*Ww, Wh*Ww), value should be between (-inf, 0]. """ + ctx = FUNCTION_REWRITER.get_context() B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).contiguous() @@ -129,7 +130,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_reverse', backend='tensorrt') -def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W): +def shift_window_msa__window_reverse__tensorrt(self, windows, H, W): """Rewrite window_reverse function of ShiftWindowMSA class for TensorRT. For TensorRT, seems radical shape transformations are not allowed. Replace them with soft ones. 
@@ -155,7 +156,7 @@ def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_partition', backend='tensorrt') -def shift_window_msa__window_partition__tensorrt(ctx, self, x): +def shift_window_msa__window_partition__tensorrt(self, x): """Rewrite window_partition function of ShiftWindowMSA class for TensorRT. For TensorRT, seems radical shape transformations are not allowed. Replace them with soft ones. @@ -176,7 +177,7 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward') -def shift_window_msa__forward__default(ctx, self, query, hw_shape): +def shift_window_msa__forward__default(self, query, hw_shape): """Rewrite forward function of ShiftWindowMSA class. 1. replace dynamic padding with static padding and dynamic slice. diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index e4aaf8c4c..670cb1087 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -24,7 +24,6 @@ from mmdeploy.utils import Backend, is_dynamic_shape func_name='mmdet.models.dense_heads.base_dense_head.' 
'BaseDenseHead.predict_by_feat') def base_dense_head__predict_by_feat( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -66,6 +65,7 @@ def base_dense_head__predict_by_feat( tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg is_dynamic_flag = is_dynamic_shape(deploy_cfg) num_levels = len(cls_scores) @@ -211,7 +211,6 @@ def base_dense_head__predict_by_feat( 'BaseDenseHead.predict_by_feat', backend=Backend.RKNN.value) def base_dense_head__predict_by_feat__rknn( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -253,6 +252,8 @@ def base_dense_head__predict_by_feat__rknn( tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() + # mark nodes for partition @mark('BaseDenseHead', outputs=['BaseDenseHead.cls', 'BaseDenseHead.loc']) def __mark_dense_head(cls_scores, bbox_preds): @@ -337,7 +338,6 @@ def base_dense_head__predict_by_feat__rknn( 'BaseDenseHead.predict_by_feat', backend=Backend.NCNN.value) def base_dense_head__predict_by_feat__ncnn( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -376,6 +376,7 @@ def base_dense_head__predict_by_feat__ncnn( Returns: output__ncnn (Tensor): outputs, shape is [N, num_det, 6]. 
""" + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\ diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py index e5130489b..9453b2dac 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/centernet_head.py @@ -9,7 +9,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.centernet_head.CenterNetHead.predict_by_feat') def centernet_head__predict_by_feat__default( - ctx, self, center_heatmap_preds: List[Tensor], wh_preds: List[Tensor], diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py index 3ef050d5c..8af369913 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/detr_head.py @@ -10,7 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.DETRHead.forward_single') -def detrhead__forward_single__default(ctx, self, x, img_metas): +def detrhead__forward_single__default(self, x, img_metas): """forward_single of DETRHead. 
Ease the mask computation @@ -35,8 +35,7 @@ def detrhead__forward_single__default(ctx, self, x, img_metas): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.DETRHead.predict_by_feat') -def detrhead__predict_by_feat__default(ctx, - self, +def detrhead__predict_by_feat__default(self, all_cls_scores_list: List[Tensor], all_bbox_preds_list: List[Tensor], batch_img_metas: List[dict], diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py index 39fd30f61..8d7d318dd 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py @@ -13,8 +13,7 @@ from mmdeploy.mmcv.ops import multiclass_nms @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.fovea_head.FoveaHead.predict_by_feat') -def fovea_head__predict_by_feat(ctx, - self, +def fovea_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -49,6 +48,7 @@ def fovea_head__predict_by_feat(ctx, `dets` of shape [N, num_det, 5] and `labels` of shape [N, num_det]. """ + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) cfg = self.test_cfg if cfg is None else cfg num_levels = len(cls_scores) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py index 8e8348ccd..04d5caa61 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py @@ -18,8 +18,7 @@ from mmdeploy.utils import Backend, get_backend, is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.gfl_head.' 
'GFLHead.predict_by_feat') -def gfl_head__predict_by_feat(ctx, - self, +def gfl_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -58,6 +57,7 @@ def gfl_head__predict_by_feat(ctx, tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg is_dynamic_flag = is_dynamic_shape(deploy_cfg) backend = get_backend(deploy_cfg) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py index 7eeb594af..24d4de8d7 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py @@ -35,12 +35,13 @@ def _bbox_post_decode(bboxes: torch.Tensor, max_shape: Sequence[int]): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.reppoints_head.RepPointsHead.points2bbox') -def reppoints_head__points2bbox(ctx, self, pts, y_first=True): +def reppoints_head__points2bbox(self, pts, y_first=True): """Rewrite of `points2bbox` in `RepPointsHead`. Use `self.moment_transfer` in `points2bbox` will cause error: RuntimeError: Input, output and indices must be on the current device """ + ctx = FUNCTION_REWRITER.get_context() update_moment = hasattr(self, 'moment_transfer') if update_moment: moment_transfer = self.moment_transfer @@ -55,7 +56,6 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True): @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.reppoints_head.RepPointsHead.predict_by_feat') def reppoints_head__predict_by_feat( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -91,6 +91,7 @@ def reppoints_head__predict_by_feat( `dets` of shape [N, num_det, 5] and `labels` of shape [N, num_det]. 
""" + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg is_dynamic_flag = is_dynamic_shape(deploy_cfg) num_levels = len(cls_scores) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py index 344bcdc15..70b542c54 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py @@ -16,8 +16,7 @@ from mmdeploy.utils import Backend, is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.rpn_head.' 'RPNHead.predict_by_feat') -def rpn_head__predict_by_feat(ctx, - self, +def rpn_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -61,6 +60,7 @@ def rpn_head__predict_by_feat(ctx, tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() img_metas = batch_img_metas assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg @@ -166,8 +166,7 @@ def rpn_head__predict_by_feat(ctx, # TODO: Fix for 1.x @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value) -def rpn_head__get_bboxes__ncnn(ctx, - self, +def rpn_head__get_bboxes__ncnn(self, cls_scores, bbox_preds, img_metas, @@ -204,6 +203,7 @@ def rpn_head__get_bboxes__ncnn(ctx, Else: tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores """ + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg assert not is_dynamic_shape(deploy_cfg) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py index f3456e88b..80d57cc93 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py @@ -14,8 +14,7 @@ from mmdeploy.mmcv.ops import multiclass_nms 
@FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.rtmdet_head.' 'RTMDetHead.predict_by_feat') -def rtmdet_head__predict_by_feat(ctx, - self, +def rtmdet_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], batch_img_metas: Optional[List[dict]] = None, @@ -52,6 +51,7 @@ def rtmdet_head__predict_by_feat(ctx, tensor in the tuple is (N, num_box), and each element represents the class label of the corresponding box. """ + ctx = FUNCTION_REWRITER.get_context() assert len(cls_scores) == len(bbox_preds) device = cls_scores[0].device cfg = self.test_cfg if cfg is None else cfg diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py index f1f19665c..73ac9f921 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py @@ -15,8 +15,7 @@ from mmdeploy.utils import Backend, is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.yolo_head.' 'YOLOV3Head.predict_by_feat') -def yolov3_head__predict_by_feat(ctx, - self, +def yolov3_head__predict_by_feat(self, pred_maps: Sequence[Tensor], cfg: OptConfigType = None, rescale: bool = False, @@ -47,6 +46,7 @@ def yolov3_head__predict_by_feat(ctx, Else: tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores """ + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg # mark pred_maps @@ -154,8 +154,7 @@ def yolov3_head__predict_by_feat(ctx, @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.YOLOV3Head.predict_by_feat', backend=Backend.NCNN.value) -def yolov3_head__predict_by_feat__ncnn(ctx, - self, +def yolov3_head__predict_by_feat__ncnn(self, pred_maps, with_nms=True, cfg=None, @@ -188,6 +187,7 @@ def yolov3_head__predict_by_feat__ncnn(ctx, fore-ground class label in Yolov3DetectionOutput starts from `1`. x1, y1, x2, y2 are normalized in range(0,1). 
""" + ctx = FUNCTION_REWRITER.get_context() num_levels = len(pred_maps) cfg = self.test_cfg if cfg is None else cfg post_params = get_post_processing_params(ctx.cfg) diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py index 071b33a7e..8b18fceea 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py @@ -15,8 +15,7 @@ from mmdeploy.utils import Backend @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.yolox_head.' 'YOLOXHead.predict_by_feat') -def yolox_head__predict_by_feat(ctx, - self, +def yolox_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], objectnesses: Optional[List[Tensor]], @@ -57,6 +56,8 @@ def yolox_head__predict_by_feat(ctx, tensor in the tuple is (N, num_box), and each element represents the class label of the corresponding box. """ + ctx = FUNCTION_REWRITER.get_context() + # mark pred_maps @mark('yolo_head', inputs=['cls_scores', 'bbox_preds', 'objectnesses']) def __mark_pred_maps(cls_scores, bbox_preds, objectnesses): @@ -125,7 +126,6 @@ def yolox_head__predict_by_feat(ctx, 'YOLOXHead.predict_by_feat', backend=Backend.NCNN.value) def yolox_head__predict_by_feat__ncnn( - ctx, self, cls_scores: List[Tensor], bbox_preds: List[Tensor], @@ -169,6 +169,7 @@ def yolox_head__predict_by_feat__ncnn( Returns: output__ncnn (Tensor): outputs, shape is [N, num_det, 6]. 
""" + ctx = FUNCTION_REWRITER.get_context() from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward from mmdeploy.utils import get_root_logger from mmdeploy.utils.config_utils import is_dynamic_shape diff --git a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py index adfb6831f..5f3872c8b 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/single_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/single_stage.py @@ -12,7 +12,7 @@ from mmdeploy.utils import is_dynamic_shape @mark( 'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks']) -def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs): +def __forward_impl(self, batch_inputs, data_samples): """Rewrite and adding mark for `forward`. Encapsulate this function for rewriting `forward` of BaseDetector. @@ -25,10 +25,30 @@ def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs): return output +@torch.fx.wrap +def _set_metainfo(data_samples, img_shape): + """Set the metainfo. + + Code in this function cannot be traced by fx. + """ + + # fx can not trace deepcopy correctly + data_samples = copy.deepcopy(data_samples) + if data_samples is None: + data_samples = [DetDataSample()] + + # note that we can not use `set_metainfo`, deepcopy would crash the + # onnx trace. + for data_sample in data_samples: + data_sample.set_field( + name='img_shape', value=img_shape, field_type='metainfo') + + return data_samples + + @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.single_stage.SingleStageDetector.forward') -def single_stage_detector__forward(ctx, - self, +def single_stage_detector__forward(self, batch_inputs: torch.Tensor, data_samples: OptSampleList = None, mode: str = 'tensor', @@ -53,9 +73,8 @@ def single_stage_detector__forward(ctx, - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). 
""" - data_samples = copy.deepcopy(data_samples) - if data_samples is None: - data_samples = [DetDataSample()] + ctx = FUNCTION_REWRITER.get_context() + deploy_cfg = ctx.cfg # get origin input shape as tensor to support onnx dynamic shape @@ -65,11 +84,6 @@ def single_stage_detector__forward(ctx, img_shape = [int(val) for val in img_shape] # set the metainfo - # note that we can not use `set_metainfo`, deepcopy would crash the - # onnx trace. - for data_sample in data_samples: - data_sample.set_field( - name='img_shape', value=img_shape, field_type='metainfo') + data_samples = _set_metainfo(data_samples, img_shape) - return __forward_impl( - ctx, self, batch_inputs, data_samples=data_samples, **kwargs) + return __forward_impl(self, batch_inputs, data_samples=data_samples) diff --git a/mmdeploy/codebase/mmdet/models/detectors/two_stage.py b/mmdeploy/codebase/mmdet/models/detectors/two_stage.py index 9b20fed83..d0bd14000 100644 --- a/mmdeploy/codebase/mmdet/models/detectors/two_stage.py +++ b/mmdeploy/codebase/mmdet/models/detectors/two_stage.py @@ -11,8 +11,7 @@ from mmdeploy.utils import is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.two_stage.TwoStageDetector.extract_feat') -@mark('extract_feat', inputs='img', outputs='feat') -def two_stage_detector__extract_feat(ctx, self, img): +def two_stage_detector__extract_feat(self, img): """Rewrite `extract_feat` for default backend. This function uses the specific `extract_feat` function for the two @@ -27,13 +26,18 @@ def two_stage_detector__extract_feat(ctx, self, img): list[Tensor]: Each item with shape (N, C, H, W) corresponds one level of backbone and neck features. 
""" - return ctx.origin_func(self, img) + ctx = FUNCTION_REWRITER.get_context() + + @mark('extract_feat', inputs='img', outputs='feat') + def __extract_feat_impl(self, img): + return ctx.origin_func(self, img) + + return __extract_feat_impl(self, img) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.detectors.two_stage.TwoStageDetector.forward') -def two_stage_detector__forward(ctx, - self, +def two_stage_detector__forward(self, batch_inputs: torch.Tensor, data_samples: OptSampleList = None, mode: str = 'tensor', @@ -58,6 +62,7 @@ def two_stage_detector__forward(ctx, - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). """ + ctx = FUNCTION_REWRITER.get_context() data_samples = copy.deepcopy(data_samples) deploy_cfg = ctx.cfg diff --git a/mmdeploy/codebase/mmdet/models/necks.py b/mmdeploy/codebase/mmdet/models/necks.py index adc40fa12..4ea29db5f 100644 --- a/mmdeploy/codebase/mmdet/models/necks.py +++ b/mmdeploy/codebase/mmdet/models/necks.py @@ -7,7 +7,7 @@ from mmdeploy.utils import Backend, get_root_logger @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.necks.ssd_neck.L2Norm.forward') -def l2norm__forward__default(ctx, self, x): +def l2norm__forward__default(self, x): """Default rewriter for l2norm. Implement with functinoal.normalize . @@ -19,11 +19,12 @@ def l2norm__forward__default(ctx, self, x): @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.necks.ssd_neck.L2Norm.forward', backend=Backend.TENSORRT.value) -def l2norm__forward__tensorrt(ctx, self, x): +def l2norm__forward__tensorrt(self, x): """rewrite `l2norm` for TensorRT. TensorRT7 does not support dynamic clamp, which is used in normalize. 
""" + ctx = FUNCTION_REWRITER.get_context() logger = get_root_logger() trt_version_major = 8 try: @@ -34,6 +35,6 @@ def l2norm__forward__tensorrt(ctx, self, x): except Exception: logger.warning('Can not get TensorRT version.') if trt_version_major >= 8: - return l2norm__forward__default(ctx, self, x) + return l2norm__forward__default(self, x) else: return ctx.origin_func(self, x) diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py index e8387fe5e..da0765aa7 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py @@ -17,11 +17,7 @@ from mmdeploy.mmcv.ops import multiclass_nms @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.ConvFCBBoxHead.forward' ) -@mark( - 'bbox_head_forward', - inputs=['bbox_feats'], - outputs=['cls_score', 'bbox_pred']) -def bbox_head__forward(ctx, self, x): +def bbox_head__forward(self, x): """Rewrite `forward` for default backend. This function uses the specific `forward` function for the BBoxHead @@ -37,13 +33,21 @@ def bbox_head__forward(ctx, self, x): has shape (N, num_det, num_cls) and the bbox_pred has shape (N, num_det, 4). """ - return ctx.origin_func(self, x) + ctx = FUNCTION_REWRITER.get_context() + + @mark( + 'bbox_head_forward', + inputs=['bbox_feats'], + outputs=['cls_score', 'bbox_pred']) + def __forward(self, x): + return ctx.origin_func(self, x) + + return __forward(self, x) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.bbox_heads.bbox_head.BBoxHead.predict_by_feat') -def bbox_head__predict_by_feat(ctx, - self, +def bbox_head__predict_by_feat(self, rois: Tuple[Tensor], cls_scores: Tuple[Tensor], bbox_preds: Tuple[Tensor], @@ -75,6 +79,7 @@ def bbox_head__predict_by_feat(ctx, - labels (Tensor): Labels of bboxes, has a shape (num_instances, ). 
""" + ctx = FUNCTION_REWRITER.get_context() assert rois.ndim == 3, 'Only support export two stage ' \ 'model to ONNX ' \ 'with batch dimension. ' diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py index ef926f550..f302d1378 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py @@ -11,8 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_bbox') -def cascade_roi_head__predict_bbox(ctx, - self, +def cascade_roi_head__predict_bbox(self, x: Tuple[Tensor], batch_img_metas: List[dict], rpn_results_list: List[Tensor], @@ -86,8 +85,7 @@ def cascade_roi_head__predict_bbox(ctx, @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_mask') -def cascade_roi_head__predict_mask(ctx, - self, +def cascade_roi_head__predict_mask(self, x: Tuple[Tensor], batch_img_metas: List[dict], results_list: List[Tensor], diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py index 360faeb1a..9371ff552 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py @@ -14,8 +14,7 @@ from mmdeploy.utils import Backend, get_backend @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.' 'mask_heads.fcn_mask_head.FCNMaskHead.predict_by_feat') -def fcn_mask_head__predict_by_feat(ctx, - self, +def fcn_mask_head__predict_by_feat(self, mask_preds: Tuple[Tensor], results_list: List[Tensor], batch_img_metas: List[dict], @@ -48,6 +47,7 @@ def fcn_mask_head__predict_by_feat(ctx, (num_instances, ). - masks (Tensor): Has a shape (num_instances, H, W). 
""" + ctx = FUNCTION_REWRITER.get_context() ori_shape = batch_img_metas[0]['img_shape'] dets, det_labels = results_list dets = dets.view(-1, 5) diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py index ab20541c9..d8f53ad0f 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/single_level_roi_extractor.py @@ -61,13 +61,12 @@ class MultiLevelRoiAlign(Function): (num_proposals, channel, output_size[1], output_size[0])) +@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.roi_extractors.' 'single_level_roi_extractor.SingleRoIExtractor.forward', backend='tensorrt') -@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def single_roi_extractor__forward__tensorrt(ctx, - self, +def single_roi_extractor__forward__tensorrt(self, feats, rois, roi_scale_factor=None): @@ -154,8 +153,7 @@ class AscendRoiExtractor(Function): 'mmdet.models.roi_heads.roi_extractors.' 'single_level_roi_extractor.SingleRoIExtractor.forward', backend='ascend') -def single_roi_extractor__forward__ascend(ctx, - self, +def single_roi_extractor__forward__ascend(self, feats, rois, roi_scale_factor=None): @@ -185,14 +183,10 @@ def single_roi_extractor__forward__ascend(ctx, finest_scale, featmap_strides, aligned) +@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward') -@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def single_roi_extractor__forward(ctx, - self, - feats, - rois, - roi_scale_factor=None): +def single_roi_extractor__forward(self, feats, rois, roi_scale_factor=None): """Rewrite `forward` of SingleRoIExtractor for default backend. 
Rewrite this function to: @@ -206,6 +200,8 @@ def single_roi_extractor__forward(ctx, 3. use the roi align in torhcvision to accelerate the inference. """ + ctx = FUNCTION_REWRITER.get_context( + 'mmdet.models.roi_heads.SingleRoIExtractor.forward') backend = get_backend(ctx.cfg) out_size = self.roi_layers[0].output_size num_levels = len(feats) @@ -291,8 +287,7 @@ class SingleRoIExtractorOpenVINO(Function): 'mmdet.models.roi_heads.roi_extractors.' 'single_level_roi_extractor.SingleRoIExtractor.forward', backend='openvino') -def single_roi_extractor__forward__openvino(ctx, - self, +def single_roi_extractor__forward__openvino(self, feats, rois, roi_scale_factor=None): @@ -301,6 +296,7 @@ def single_roi_extractor__forward__openvino(ctx, This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO. """ + ctx = FUNCTION_REWRITER.get_context() # Adding original output to SingleRoIExtractorOpenVINO. state = torch._C._get_tracing_state() @@ -317,12 +313,11 @@ def single_roi_extractor__forward__openvino(ctx, return result +@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward', backend=Backend.COREML.value) -@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def single_roi_extractor__forward__coreml(ctx, - self, +def single_roi_extractor__forward__coreml(self, feats, rois, roi_scale_factor=None): diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py index c62c8c70e..aa7801dd8 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/standard_roi_head.py @@ -10,8 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_bbox') -def standard_roi_head__predict_bbox(ctx, - self, +def 
standard_roi_head__predict_bbox(self, x: Tuple[Tensor], batch_img_metas: List[dict], rpn_results_list: List[Tensor], @@ -72,8 +71,7 @@ def standard_roi_head__predict_bbox(ctx, @FUNCTION_REWRITER.register_rewriter( 'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_mask') -def standard_roi_head__predict_mask(ctx, - self, +def standard_roi_head__predict_mask(self, x: Tuple[Tensor], batch_img_metas: List[dict], results_list: List[Tensor], diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py index e4e8f8e82..8a5394423 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py @@ -9,8 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER func_name='mmdet.models.task_modules.coders.delta_xywh_bbox_coder.' 'DeltaXYWHBBoxCoder.decode', backend='default') -def deltaxywhbboxcoder__decode(ctx, - self, +def deltaxywhbboxcoder__decode(self, bboxes, pred_bboxes, max_shape=None, @@ -51,8 +50,7 @@ def deltaxywhbboxcoder__decode(ctx, func_name='mmdet.models.task_modules.coders' '.delta_xywh_bbox_coder.delta2bbox', backend='default') -def delta2bbox(ctx, - rois, +def delta2bbox(rois, deltas, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.), @@ -143,8 +141,7 @@ def delta2bbox(ctx, func_name='mmdet.models.task_modules.coders.' 
'delta_xywh_bbox_coder.delta2bbox', backend='ncnn') -def delta2bbox__ncnn(ctx, - rois, +def delta2bbox__ncnn(rois, deltas, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.), diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py index 41da1bdfc..8b8bbdb0e 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/distance_point_bbox_coder.py @@ -8,11 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER func_name='mmdet.models.task_modules.coders.distance_point_bbox_coder' '.DistancePointBBoxCoder.decode', backend='default') -def distancepointbboxcoder__decode(ctx, - self, - points, - pred_bboxes, - max_shape=None): +def distancepointbboxcoder__decode(self, points, pred_bboxes, max_shape=None): """Rewrite `mmdet.models.task_modules.coders.distance_point_bbox_coder. \ DistancePointBBoxCoder.decode` diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py index c5ca8cd37..b0f56676c 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/tblr_bbox_coder.py @@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.task_modules.coders.tblr_bbox_coder.tblr2bboxes', backend='default') -def tblr2bboxes(ctx, - priors, +def tblr2bboxes(priors, tblr, normalizer=4.0, normalize_by_wh=True, diff --git a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py index 126063fae..b16ec26fa 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/anchor.py @@ 
-77,7 +77,6 @@ grid_priors_trt = GridPriorsTRTOp.apply 'AnchorGenerator.single_level_grid_priors', backend='tensorrt') def anchorgenerator__single_level_grid_priors__trt( - ctx, self, featmap_size: Tuple[int], level_idx: int, @@ -98,6 +97,7 @@ def anchorgenerator__single_level_grid_priors__trt( Returns: torch.Tensor: Anchors in the overall feature maps. """ + ctx = FUNCTION_REWRITER.get_context() from mmdet.models.task_modules.prior_generators import AnchorGenerator if type(self) != AnchorGenerator: # only use custom node on default generator. diff --git a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py index 91e54692c..bfeb91bdc 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/point_generator.py @@ -10,7 +10,6 @@ from mmdeploy.utils.constants import Backend '.single_level_grid_priors', backend=Backend.TENSORRT.value) def mlvl_point_generator__single_level_grid_priors__tensorrt( - ctx, self, featmap_size, level_idx, diff --git a/mmdeploy/codebase/mmdet/models/transformer.py b/mmdeploy/codebase/mmdet/models/transformer.py index 7ff62c675..e89a50669 100644 --- a/mmdeploy/codebase/mmdet/models/transformer.py +++ b/mmdeploy/codebase/mmdet/models/transformer.py @@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.utils.transformer.PatchMerging.forward', backend='tensorrt') -def patch_merging__forward__tensorrt(ctx, self, x, input_size): +def patch_merging__forward__tensorrt(self, x, input_size): """Rewrite forward function of PatchMerging class for TensorRT. In original implementation, mmdet applies nn.unfold to accelerate the inference. However, the onnx graph of it can not be parsed correctly by TensorRT. 
In diff --git a/mmdeploy/codebase/mmdet/structures/bbox/transforms.py b/mmdeploy/codebase/mmdet/structures/bbox/transforms.py index 727ce0f45..aed156f51 100644 --- a/mmdeploy/codebase/mmdet/structures/bbox/transforms.py +++ b/mmdeploy/codebase/mmdet/structures/bbox/transforms.py @@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.structures.bbox.transforms.distance2bbox' # noqa ) -def distance2bbox__default(ctx, points, distance, max_shape=None): +def distance2bbox__default(points, distance, max_shape=None): """Rewrite `mmdet.core.bbox.transforms.distance2bbox` Decode distance prediction to bounding box. diff --git a/mmdeploy/codebase/mmdet3d/models/base.py b/mmdeploy/codebase/mmdet3d/models/base.py index 38d35cd95..4410e77e2 100644 --- a/mmdeploy/codebase/mmdet3d/models/base.py +++ b/mmdeploy/codebase/mmdet3d/models/base.py @@ -9,8 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501 ) -def basedetector__forward(ctx, - self, +def basedetector__forward(self, inputs: list, data_samples=None, **kwargs) -> Tuple[List[torch.Tensor]]: diff --git a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py index 83ee17088..12df74ff5 100644 --- a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py +++ b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py @@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_img_feat' # noqa: E501 ) -def mvxtwostagedetector__extract_img_feat(ctx, self, - img: torch.Tensor) -> dict: +def mvxtwostagedetector__extract_img_feat(self, img: torch.Tensor) -> dict: """Extract features of images.""" if self.with_img_backbone and img is not None: if img.dim() == 5 and img.size(0) == 1: @@ -26,8 +25,7 @@ def 
mvxtwostagedetector__extract_img_feat(ctx, self, @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_feat') -def mvxtwostagedetector__extract_feat(ctx, self, - batch_inputs_dict: dict) -> tuple: +def mvxtwostagedetector__extract_feat(self, batch_inputs_dict: dict) -> tuple: """Rewrite this func to remove voxelize op. Args: @@ -47,7 +45,7 @@ def mvxtwostagedetector__extract_feat(ctx, self, @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.forward') -def mvxtwostagedetector__forward(ctx, self, inputs: list, **kwargs): +def mvxtwostagedetector__forward(self, inputs: list, **kwargs): """Rewrite this func to remove voxelize op. Args: diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py index 4908a5707..0a327ad29 100644 --- a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py +++ b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py @@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.voxel_encoders.pillar_encoder.PillarFeatureNet.forward') -def pillar_encoder__forward(ctx, self, features, num_points, coors, *args, +def pillar_encoder__forward(self, features, num_points, coors, *args, **kwargs): """Rewrite this func to optimize node. Modify the code at _with_voxel_center and use slice instead of the original operation. diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py index 66ae455b5..34351ccbc 100644 --- a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py +++ b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py @@ -7,11 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.middle_encoders.pillar_scatter.' 
'PointPillarsScatter.forward_batch', ) -def pointpillarsscatter__forward(ctx, - self, - voxel_features, - coors, - batch_size=1): +def pointpillarsscatter__forward(self, voxel_features, coors, batch_size=1): """Scatter features of single sample. Args: diff --git a/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py b/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py index eb3dad7dd..620165ec6 100644 --- a/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py +++ b/mmdeploy/codebase/mmedit/models/base_models/base_edit_model.py @@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmedit.models.base_models.BaseEditModel.forward', backend='default') def base_edit_model__forward( - ctx, self, batch_inputs: Tensor, data_samples: Optional[List[BaseDataElement]] = None, diff --git a/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py b/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py index c188d7e56..39acc2ea5 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/fpn_cat.py @@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.FPNC.forward', backend='tensorrt') -def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs): +def fpnc__forward__tensorrt(self, inputs, **kwargs): """Rewrite `forward` of FPNC for tensorrt backend. 
Rewrite this function to replace nearest upsampling with bilinear diff --git a/mmdeploy/codebase/mmocr/models/text_detection/heads.py b/mmdeploy/codebase/mmocr/models/text_detection/heads.py index 9c5801852..d23c957ff 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/heads.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/heads.py @@ -10,7 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.heads.BaseTextDetHead.predict') def base_text_det_head__predict( - ctx, self, x: torch.Tensor, + self, x: torch.Tensor, batch_data_samples: DetSampleList) -> DetSampleList: """Rewrite `predict` of BaseTextDetHead for default backend. @@ -38,7 +38,7 @@ def base_text_det_head__predict( @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.heads.DBHead.predict') -def db_head__predict(ctx, self, x: torch.Tensor, +def db_head__predict(self, x: torch.Tensor, batch_data_samples: DetSampleList) -> DetSampleList: """Rewrite to avoid post-process of text detection head. 
diff --git a/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py b/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py index ea72eae89..0313097af 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py @@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textdet.SingleStageTextDetector.forward') def single_stage_text_detector__forward( - ctx, self, batch_inputs: torch.Tensor, data_samples: TextDetDataSample = None, diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py index 26adccaec..036e95218 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py @@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textrecog.decoders.BaseDecoder.predict') def base_decoder__forward( - ctx, self, feat: Optional[torch.Tensor] = None, out_enc: Optional[torch.Tensor] = None, diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py index 76cb318f6..ce0696e0b 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/crnn_decoder.py @@ -5,7 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textrecog.decoders.CRNNDecoder.forward_train', backend='ncnn') -def crnndecoder__forward_train__ncnn(ctx, self, feat, *args, **kwargs): +def crnndecoder__forward_train__ncnn(self, feat, *args, **kwargs): """Rewrite `forward_train` of CRNNDecoder for ncnn backend. 
Rewrite this function to skip permuting dims of outputs from `[W, N, C]` to diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py b/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py index 155ece62c..041ab9758 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py @@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmocr.models.textrecog.EncoderDecoderRecognizer.forward') -def encoder_decoder_recognizer__forward(ctx, self, batch_inputs: torch.Tensor, +def encoder_decoder_recognizer__forward(self, batch_inputs: torch.Tensor, data_samples: TextRecogDataSample, **kwargs) -> TextRecogDataSample: """Rewrite `forward` of EncoderDecoderRecognizer for default backend. diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py b/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py index bd0d5df34..b181daaaa 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/lstm_layer.py @@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER func_name='mmocr.models.textrecog.layers.lstm_layer' '.BidirectionalLSTM.forward', backend='ncnn') -def bidirectionallstm__forward__ncnn(ctx, self, input): +def bidirectionallstm__forward__ncnn(self, input): """Rewrite `forward` of BidirectionalLSTM for ncnn backend. Rewrite this function to set batch_first of rnn layer to true. 
RNN in ncnn diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py index c6fc68c83..38b51e092 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py @@ -15,7 +15,6 @@ from mmdeploy.core import FUNCTION_REWRITER, MODULE_REWRITER '._2d_attention', backend='default') def parallel_sar_decoder__2d_attention( - ctx, self, decoder_input: torch.Tensor, feat: torch.Tensor, @@ -85,8 +84,7 @@ def parallel_sar_decoder__2d_attention( func_name='mmocr.models.textrecog.decoders.SequentialSARDecoder' '._2d_attention', backend='default') -def sequential_sar_decoder__2d_attention(ctx, - self, +def sequential_sar_decoder__2d_attention(self, y_prev, feat, holistic_feat, @@ -151,7 +149,6 @@ def sequential_sar_decoder__2d_attention(ctx, '.forward_test', backend='default') def sequential_sar_decoder__forward_test( - ctx, self, feat: torch.Tensor, out_enc: torch.Tensor, diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py index 8c756fc00..dc5a87f6f 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py @@ -12,7 +12,6 @@ from mmdeploy.core import FUNCTION_REWRITER func_name='mmocr.models.textrecog.encoders.SAREncoder.forward', backend='default') def sar_encoder__forward( - ctx, self, feat: torch.Tensor, data_samples: Optional[Sequence[TextRecogDataSample]] = None): diff --git a/mmdeploy/codebase/mmpose/models/heads/mspn_head.py b/mmdeploy/codebase/mmpose/models/heads/mspn_head.py index 2c92d0271..7b391040f 100644 --- a/mmdeploy/codebase/mmpose/models/heads/mspn_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/mspn_head.py @@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER 'mmpose.models.heads.heatmap_heads.CPMHead.forward') 
@FUNCTION_REWRITER.register_rewriter( 'mmpose.models.heads.heatmap_heads.MSPNHead.forward') -def mspn_head__forward(ctx, self, feats): +def mspn_head__forward(self, feats): """Rewrite `forward` of MSPNHead and CPMHead for default backend. 1. return last stage heatmaps directly. @@ -18,6 +18,7 @@ def mspn_head__forward(ctx, self, feats): Returns: output_heatmap (torch.Tensor): Output heatmaps. """ + ctx = FUNCTION_REWRITER.get_context() msmu_batch_heatmaps = ctx.origin_func(self, feats) batch_heatmaps = msmu_batch_heatmaps[-1] return batch_heatmaps diff --git a/mmdeploy/codebase/mmpose/models/pose_estimators/base.py b/mmdeploy/codebase/mmpose/models/pose_estimators/base.py index a0e11e45f..3962a4f3f 100644 --- a/mmdeploy/codebase/mmpose/models/pose_estimators/base.py +++ b/mmdeploy/codebase/mmpose/models/pose_estimators/base.py @@ -4,7 +4,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmpose.models.pose_estimators.base.BasePoseEstimator.forward') -def base_pose_estimator__forward(ctx, self, inputs, *args, **kwargs): +def base_pose_estimator__forward(self, inputs, *args, **kwargs): """Rewrite `forward` of TopDown for default backend.'. 1.directly call _forward of subclass. 
diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py index ab091f979..14331f3c1 100644 --- a/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py @@ -17,8 +17,7 @@ from mmdeploy.utils import is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( 'mmrotate.models.dense_heads.OrientedRPNHead.predict_by_feat') -def rpn_head__predict_by_feat(ctx, - self, +def rpn_head__predict_by_feat(self, cls_scores: List[Tensor], bbox_preds: List[Tensor], score_factors: Optional[List[Tensor]] = None, @@ -62,6 +61,7 @@ def rpn_head__predict_by_feat(ctx, tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness """ + ctx = FUNCTION_REWRITER.get_context() img_metas = batch_img_metas assert len(cls_scores) == len(bbox_preds) deploy_cfg = ctx.cfg diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py index 33d304d7a..f070d96f2 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -15,8 +15,7 @@ from mmdeploy.mmcv.ops import multiclass_nms @FUNCTION_REWRITER.register_rewriter( 'mmrotate.models.roi_heads.bbox_heads.GVBBoxHead.predict_by_feat') -def gv_bbox_head__predict_by_feat(ctx, - self, +def gv_bbox_head__predict_by_feat(self, rois: Tuple[Tensor], cls_scores: Tuple[Tensor], bbox_preds: Tuple[Tensor], @@ -60,6 +59,7 @@ def gv_bbox_head__predict_by_feat(ctx, assert rois.ndim == 3, 'Only support export two stage ' \ 'model to ONNX ' \ 'with batch dimension. 
' + ctx = FUNCTION_REWRITER.get_context() img_shape = batch_img_metas[0]['img_shape'] if self.custom_cls_channels: diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py index bc2b85589..f2c17033c 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_ratio_roi_head.py @@ -11,8 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmrotate.models.roi_heads.gv_ratio_roi_head' '.GVRatioRoIHead.predict_bbox') -def gv_ratio_roi_head__predict_bbox(ctx, - self, +def gv_ratio_roi_head__predict_bbox(self, x: Tuple[Tensor], batch_img_metas: List[dict], rpn_results_list: InstanceList, diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/roi_extractors.py b/mmdeploy/codebase/mmrotate/models/roi_heads/roi_extractors.py index f48e0dcf3..4e01c66ed 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/roi_extractors.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/roi_extractors.py @@ -66,8 +66,7 @@ class MultiLevelRotatedRoiAlign(Function): backend='tensorrt') @mark( 'rotated_roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats']) -def rotated_single_roi_extractor__forward__tensorrt(ctx, - self, +def rotated_single_roi_extractor__forward__tensorrt(self, feats, rois, roi_scale_factor=None): diff --git a/mmdeploy/codebase/mmrotate/models/task_modules/coders.py b/mmdeploy/codebase/mmrotate/models/task_modules/coders.py index c95028faa..886fbcb25 100644 --- a/mmdeploy/codebase/mmrotate/models/task_modules/coders.py +++ b/mmdeploy/codebase/mmrotate/models/task_modules/coders.py @@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmrotate.models.task_modules.coders.gliding_vertex_coder' '.GVFixCoder.decode') -def gvfixcoder__decode(ctx, self, hboxes, fix_deltas): +def gvfixcoder__decode(self, hboxes, fix_deltas): 
"""Rewriter for GVFixCoder decode, support more dimension input.""" assert hboxes.size( diff --git a/mmdeploy/codebase/mmrotate/structures/bbox.py b/mmdeploy/codebase/mmrotate/structures/bbox.py index 0b947cb01..dd8bbae3f 100644 --- a/mmdeploy/codebase/mmrotate/structures/bbox.py +++ b/mmdeploy/codebase/mmrotate/structures/bbox.py @@ -21,7 +21,7 @@ def _dist_torch(point1, point2): @FUNCTION_REWRITER.register_rewriter( 'mmrotate.structures.bbox.box_converters.qbox2rbox') -def qbox2rbox__default(ctx, boxes: Tensor) -> Tensor: +def qbox2rbox__default(boxes: Tensor) -> Tensor: """Convert quadrilateral boxes to rotated boxes. Implement with PyTorch. diff --git a/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py b/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py index 5d839691b..6ff07cd10 100644 --- a/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py +++ b/mmdeploy/codebase/mmseg/models/decode_heads/ema_head.py @@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.decode_heads.ema_head.EMAModule.forward') -def ema_module__forward(ctx, self, feats): +def ema_module__forward(self, feats): """Rewrite `forward` for default backend. Replace torch.einsum with other operations. diff --git a/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py b/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py index 717f5a7af..09b863a87 100644 --- a/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py +++ b/mmdeploy/codebase/mmseg/models/decode_heads/point_head.py @@ -7,8 +7,8 @@ from mmdeploy.utils import get_root_logger @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.decode_heads.point_head.PointHead.get_points_test', backend='tensorrt') -def point_head__get_points_test__tensorrt(ctx, self, seg_logits, - uncertainty_func, cfg): +def point_head__get_points_test__tensorrt(self, seg_logits, uncertainty_func, + cfg): """Sample points for testing. 1. 
set `num_points` no greater than TENSORRT_MAX_TOPK for tensorrt backend @@ -26,6 +26,7 @@ def point_head__get_points_test__tensorrt(ctx, self, seg_logits, 2) that contains [0, 1] x [0, 1] normalized coordinates of the most uncertain points from the ``height x width`` grid . """ + ctx = FUNCTION_REWRITER.get_context() from mmdeploy.utils.constants import TENSORRT_MAX_TOPK if cfg.subdivision_num_points > TENSORRT_MAX_TOPK: diff --git a/mmdeploy/codebase/mmseg/models/segmentors/base.py b/mmdeploy/codebase/mmseg/models/segmentors/base.py index 68e319622..360607407 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/base.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/base.py @@ -7,8 +7,7 @@ from mmdeploy.utils import is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.BaseSegmentor.forward') -def base_segmentor__forward(ctx, - self, +def base_segmentor__forward(self, inputs, data_samples=None, mode='predict', @@ -27,6 +26,7 @@ def base_segmentor__forward(ctx, Returns: torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. """ + ctx = FUNCTION_REWRITER.get_context() if data_samples is None: data_samples = [SegDataSample()] diff --git a/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py index 30828311a..ad8d35b81 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -4,8 +4,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.CascadeEncoderDecoder.predict') -def cascade_encoder_decoder__predict(ctx, self, inputs, data_samples, - **kwargs): +def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs): """Rewrite `predict` for default backend. 1. 
only support mode=`whole` inference diff --git a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py index 332f39bed..ee401b22b 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py @@ -5,7 +5,7 @@ from mmdeploy.utils.constants import Backend @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.EncoderDecoder.predict') -def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs): +def encoder_decoder__predict(self, inputs, data_samples, **kwargs): """Rewrite `predict` for default backend. 1. only support mode=`whole` inference @@ -32,7 +32,7 @@ def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs): @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.segmentors.EncoderDecoder.predict', backend=Backend.RKNN.value) -def encoder_decoder__predict__rknn(ctx, self, inputs, data_samples, **kwargs): +def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs): """Rewrite `predict` for RKNN backend. Early return to avoid argmax operator. diff --git a/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py b/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py index 6ccf56f2b..bc1029976 100644 --- a/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py +++ b/mmdeploy/codebase/mmseg/models/utils/up_conv_block.py @@ -8,7 +8,7 @@ from mmdeploy.utils import is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='mmseg.models.utils.UpConvBlock.forward') -def up_conv_block__forward(ctx, self, skip, x): +def up_conv_block__forward(self, skip, x): """Rewrite `forward` for default backend. To support dynamic shape for UNet backbone, @@ -23,6 +23,7 @@ def up_conv_block__forward(ctx, self, skip, x): Returns: Tensor: Upsampled output feature map. 
""" + ctx = FUNCTION_REWRITER.get_context() from mmcv.cnn import ConvModule # only valid when self.upsample is from build_upsample_layer diff --git a/mmdeploy/core/optimizers/function_marker.py b/mmdeploy/core/optimizers/function_marker.py index 57ab7ff19..41deef71b 100644 --- a/mmdeploy/core/optimizers/function_marker.py +++ b/mmdeploy/core/optimizers/function_marker.py @@ -62,18 +62,20 @@ class Mark(torch.autograd.Function): @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.core.optimizers.function_marker.Mark.symbolic') -def mark_symbolic(rewriter, g, x, *args): +def mark_symbolic(g, x, *args): """Rewrite symbolic of mark op.""" - if cfg_apply_marks(rewriter.cfg): - return rewriter.origin_func(g, x, *args) + ctx = FUNCTION_REWRITER.get_context() + if cfg_apply_marks(ctx.cfg): + return ctx.origin_func(g, x, *args) return x @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.core.optimizers.function_marker.Mark.forward') -def forward_of_mark(rewriter, ctx, x, dtype, shape, func, func_id, type, name, - id, attrs) -> torch.Tensor: +def forward_of_mark(ctx, x, dtype, shape, func, func_id, type, name, id, + attrs) -> torch.Tensor: """Rewrite forward of mark op.""" + rewriter = FUNCTION_REWRITER.get_context() deploy_cfg = rewriter.cfg # save calib data apply_marks = cfg_apply_marks(deploy_cfg) @@ -182,7 +184,7 @@ def mark_tensors(xs: Any, func: str, func_id: int, io_type: str, ctx: Any, @FUNCTION_REWRITER.register_rewriter( 'mmdeploy.core.optimizers.function_marker.mark_tensors', ir=IR.TORCHSCRIPT) -def remove_mark__torchscript(ctx, xs: Any, *args, **kwargs): +def remove_mark__torchscript(xs: Any, *args, **kwargs): """Disable all marks for TorchScript backend. 
As the Node `mark` is not able to be traced, we just return original input @@ -216,12 +218,15 @@ def mark(func_name: Optional[str] = None, >>> from mmdeploy.core import FUNCTION_REWRITER, mark >>> @FUNCTION_REWRITER.register_rewriter( >>> func_name='mmdet.models.roi_heads.ConvFCBBoxHead.forward') - >>> @mark( - >>> 'bbox_head_forward', - >>> inputs=['bbox_feats'], - >>> outputs=['cls_score', 'bbox_pred']) - >>> def forward_of_bbox_head(ctx, self, x): - >>> return ctx.origin_func(self, x) + >>> def forward_of_bbox_head(self, x): + >>> ctx = FUNCTION_REWRITER.get_context() + >>> @mark( + >>> 'bbox_head_forward', + >>> inputs=['bbox_feats'], + >>> outputs=['cls_score', 'bbox_pred']) + >>> def _impl(): + >>> return ctx.origin_func(self, x) + >>> return _impl() """ MARK_FUNCTION_COUNT[func_name] = 0 diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index b623476f3..7882f5fb3 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -1,11 +1,25 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import types +from collections import defaultdict from typing import (Any, Callable, Dict, List, MutableSequence, Optional, Tuple, Union) from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, + copy_function, get_frame_func, get_func_qualname, import_function) +try: + try: + # torch>=1.10.0 + from torch.fx._symbolic_trace import _wrapped_fns_to_patch + except ImportError: + # 1.10.0>torch>=1.8.0 + from torch.fx.symbolic_trace import _wrapped_fns_to_patch +except ImportError: + # torch<1.8.0 + _wrapped_fns_to_patch = [] + def _replace_all_obj(obj: Any, new_obj: Any, @@ -92,6 +106,24 @@ def _del_func(path: str): continue +def _fx_wrap_copied_fn(func: types.FunctionType, + copied_func: types.FunctionType): + """If a function is wrapped by torch.fx.wrap, its copy also needs to be + wrapped by torch.fx.wrap.""" + if not hasattr(func, '__globals__'): + return + + wrapped_fns_globals = [item[0] for item in _wrapped_fns_to_patch] + wrapped_fns_names = [item[1] for item in _wrapped_fns_to_patch] + + # check if wrapped by torch.fx.wrap + if func.__globals__ in wrapped_fns_globals: + idx = wrapped_fns_globals.index(func.__globals__) + fn_name = wrapped_fns_names[idx] + # a hacky way to wrap the func in copied func + _wrapped_fns_to_patch.append((copied_func.__globals__, fn_name)) + + class FunctionRewriter: """A function rewriter which maintains rewritten functions. 
@@ -102,7 +134,8 @@ class FunctionRewriter: Examples: >>> @FUNCTION_REWRITER.register_rewriter( >>> func_name='torch.Tensor.size', backend='ncnn') - >>> def size_of_tensor_static(ctx, self, *args): + >>> def size_of_tensor_static(self, *args): + >>> ctx = FUNCTION_REWRITER.get_context() >>> ret = ctx.origin_func(self, *args) >>> if isinstance(ret, torch.Tensor): >>> ret = int(ret) @@ -114,6 +147,7 @@ class FunctionRewriter: def __init__(self): self._registry = RewriterRegistry() + self._func_contexts = defaultdict(list) def register_rewriter( self, @@ -140,8 +174,11 @@ class FunctionRewriter: def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): """The implementation of function rewrite.""" + self._func_contexts.clear() # Get current records functions_records = self._registry.get_records(env) + # Get current fx wrapped func nums + self._ori_fx_wrap_num = len(_wrapped_fns_to_patch) self._origin_functions = list() self._additional_functions = list() @@ -181,15 +218,25 @@ class FunctionRewriter: # Create context_caller rewrite_function = record_dict['_object'] + # The func before and after copy has different globals + rewrite_function = copy_function(rewrite_function) extra_kwargs = kwargs.copy() extra_kwargs.update(record_dict) - context_caller = ContextCaller( - rewrite_function, origin_func, cfg, - **extra_kwargs).get_wrapped_caller() + context_caller = ContextCaller(rewrite_function, origin_func, + cfg, **extra_kwargs) + # If there is a function wrapped by torch.fx.wrap in + # rewrite_function's globals, we need to wrap the same name + # function in copied function's globals. 
+ _fx_wrap_copied_fn(record_dict['_object'], context_caller.func) + + qualname = get_func_qualname(rewrite_function) + self._func_contexts[qualname].append(context_caller) + self._func_contexts[function_path].append(context_caller) # Cache new the function to avoid homonymic bug new_functions.append( - dict(func_path=function_path, origin_func=context_caller)) + dict( + func_path=function_path, origin_func=rewrite_function)) for func_dict in new_functions: function_path = func_dict['func_path'] @@ -199,9 +246,46 @@ class FunctionRewriter: def exit(self): """Recover the function rewrite.""" + # Restore _wrapped_fns_to_patch + cur_fx_wrap_num = len(_wrapped_fns_to_patch) + for _ in range(cur_fx_wrap_num - self._ori_fx_wrap_num): + _wrapped_fns_to_patch.pop(-1) + for func_dict in self._origin_functions: func_path = func_dict['func_path'] func = func_dict['origin_func'] _set_func(func_path, func) for func_path in self._additional_functions: _del_func(func_path) + + self._func_contexts.clear() + + def get_context(self, key: Optional[str] = None) -> ContextCaller: + """Get the context of rewriter. + + Args: + key: key to the context. + + Returns: + ContextCaller: context of function + """ + func = None + if key is None: + func = get_frame_func(2) + key = get_func_qualname(func) + + # get all contexts + ctxs = self._func_contexts.get(key, []) + + if func is None: + assert len(ctxs) == 1 + return ctxs[0] + + ctx = None + for tmp_ctx in ctxs: + if tmp_ctx.func == func: + ctx = tmp_ctx + + if ctx is None: + get_root_logger().warning(f'Can not found context of {key}') + return ctx diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 7c5e4e45e..ca1e98936 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import functools import inspect +import types import warnings from abc import ABCMeta, abstractmethod from functools import wraps @@ -342,8 +344,9 @@ class ContextCaller: Example: >>> @FUNCTION_REWRITER.register_rewriter(func_name='torch.add') - >>> def func(ctx, x, y): + >>> def func(x, y): >>> # ctx is an instance of ContextCaller + >>> ctx = FUNCTION_REWRITER.get_context() >>> print(ctx.cfg) >>> return x + y """ @@ -379,3 +382,61 @@ class ContextCaller: return self.func(self, *args, **kwargs) return wrapper + + +def get_func_qualname(func: Callable) -> str: + """get function name.""" + assert isinstance(func, Callable), f'{func} is not a Callable object.' + _func_name = None + if hasattr(func, '__qualname__'): + _func_name = f'{func.__module__}.{func.__qualname__}' + elif hasattr(func, '__class__'): + _func_name = func.__class__ + else: + _func_name = str(func) + return _func_name + + +def get_frame_func(top: int = 1) -> Callable: + """get func of frame.""" + frameinfo = inspect.stack()[top] + frame = frameinfo.frame + + g_vars = frame.f_globals + func_name = frameinfo.function + assert func_name in g_vars, \ + f'Can not find function: {func_name} in global.' + func = g_vars[func_name] + return func + + +def get_frame_qualname(top: int = 1) -> str: + """get frame name.""" + frameinfo = inspect.stack()[top] + frame = frameinfo.frame + + g_vars = frame.f_globals + func_name = frameinfo.function + assert func_name in g_vars, \ + f'Can not find function: {func_name} in global.' 
+ func = g_vars[func_name] + module_name = inspect.getmodule(func).__name__ + + return f'{module_name}.{func_name}' + + +def copy_function(f: types.FunctionType): + """Copy the function.""" + # copy the global so we can get different func for different origin + glb = f.__globals__.copy() + name = f.__name__ + g = types.FunctionType( + f.__code__, + glb, + name=name, + argdefs=f.__defaults__, + closure=f.__closure__) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ + glb[name] = g + return g diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index dfcc28f76..e045dcc35 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from collections import defaultdict from typing import Callable, Dict, List, Optional, Sequence, Union import torch @@ -7,7 +8,8 @@ from torch.onnx.symbolic_helper import parse_args from mmdeploy.utils import IR, Backend, get_root_logger from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, - eval_with_import) + copy_function, eval_with_import, get_frame_func, + get_func_qualname) class SymbolicRewriter: @@ -21,7 +23,7 @@ class SymbolicRewriter: Examples: >>> @SYMBOLIC_REWRITER.register_symbolic('squeeze', \ >>> is_pytorch=True) - >>> def squeeze_default(ctx, g, self, dim=None): + >>> def squeeze_default(g, self, dim=None): >>> if dim is None: >>> dims = [] >>> for i, size in enumerate(self.type().sizes()): @@ -34,6 +36,7 @@ class SymbolicRewriter: def __init__(self) -> None: self._registry = RewriterRegistry() + self._func_contexts = defaultdict(list) def register_symbolic(self, func_name: str, @@ -75,6 +78,9 @@ class SymbolicRewriter: opset: int = 11, **kwargs): """The implementation of symbolic register.""" + # clear context + self._func_contexts.clear() + # Get current records symbolic_records = self._registry.get_records(env) @@ 
-84,19 +90,27 @@ class SymbolicRewriter: for function_name, record_dict in symbolic_records: symbolic_function = record_dict['_object'] + symbolic_function = copy_function(symbolic_function) arg_descriptors = record_dict['arg_descriptors'] extra_kwargs = kwargs.copy() extra_kwargs.update(record_dict) context_caller = ContextCaller(symbolic_function, None, cfg, **extra_kwargs) + + # register context + qualname = get_func_qualname(symbolic_function) + self._func_contexts[qualname].append(context_caller) + self._func_contexts[function_name].append(context_caller) + if arg_descriptors is not None and len(arg_descriptors) > 0: - context_caller = parse_args(*arg_descriptors)(context_caller) + symbolic_function = parse_args(*arg_descriptors)( + symbolic_function) is_pytorch = record_dict['is_pytorch'] if is_pytorch: from torch.onnx import register_custom_op_symbolic register_custom_op_symbolic(f'::{function_name}', - context_caller, opset) + symbolic_function, opset) # Save domain and version self._pytorch_symbolic.append((function_name, '', opset)) @@ -123,7 +137,7 @@ class SymbolicRewriter: self._extra_symbolic.append((origin_func, origin_symbolic)) # Cache new the function to avoid homonymic bug - new_functions.append((origin_func, context_caller)) + new_functions.append((origin_func, symbolic_function)) for origin_func, new_func in new_functions: origin_symbolic = getattr(origin_func, 'symbolic', None) @@ -132,6 +146,9 @@ class SymbolicRewriter: def exit(self): """The implementation of symbolic unregister.""" + # clear context + self._func_contexts.clear() + # Unregister pytorch op if hasattr(torch.onnx, 'unregister_custom_op_symbolic'): from torch.onnx import unregister_custom_op_symbolic @@ -149,3 +166,33 @@ class SymbolicRewriter: # Unregister custom op for origin_func, origin_symbolic in self._extra_symbolic: origin_func.symbolic = origin_symbolic + + def get_context(self, key: Optional[str] = None) -> ContextCaller: + """Get the context of rewriter. 
+ + Args: + key: key to the context. + + Returns: + ContextCaller: context of function + """ + func = None + if key is None: + func = get_frame_func(2) + key = get_func_qualname(func) + + # get all contexts + ctxs = self._func_contexts.get(key, []) + + if func is None: + assert len(ctxs) == 1 + return ctxs[0] + + ctx = None + for tmp_ctx in ctxs: + if tmp_ctx.func == func: + ctx = tmp_ctx + + if ctx is None: + get_root_logger().warning(f'Can not found context of {key}') + return ctx diff --git a/mmdeploy/mmcv/cnn/__init__.py b/mmdeploy/mmcv/cnn/__init__.py index 3b777d8b0..917a4a6df 100644 --- a/mmdeploy/mmcv/cnn/__init__.py +++ b/mmdeploy/mmcv/cnn/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import conv2d_adaptive_padding # noqa: F401,F403 from .transformer import MultiHeadAttentionop __all__ = ['MultiHeadAttentionop'] diff --git a/mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py b/mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py deleted file mode 100644 index d00184c8e..000000000 --- a/mmdeploy/mmcv/cnn/conv2d_adaptive_padding.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import math - -import torch -import torch.nn.functional as F - -from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape - - -def compute_padding(input_size, kernel_size, stride, dilation): - """Compute padding.""" - - input_h, input_w = input_size - kernel_h, kernel_w = kernel_size - stride_h, stride_w = stride - dilation_h, dilation_w = dilation - output_h = math.ceil(input_h / stride_h) - output_w = math.ceil(input_w / stride_w) - pad_h = max( - (output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h, - 0) - pad_w = max( - (output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w, - 0) - if pad_w > 0 or pad_h > 0: - padded = [ - pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 - ] - else: - padded = None - return padded - - -class AdaptivePadOp(torch.autograd.Function): - """Dummy adaptive pad op.""" - - @staticmethod - def forward(ctx, x, padded): - if padded is not None: - x = F.pad(x, padded) - return x - - @staticmethod - def symbolic(g, x, padded): - if padded is None: - return g.op('Identity', x) - padded = g.op( - 'Constant', value_t=torch.tensor(padded, dtype=torch.int64)) - constant_value = g.op( - 'Constant', value_t=torch.tensor(0, dtype=torch.int64)) - return g.op( - 'Pad', x, padded, constant_value, mode_s='constant', outputs=1) - - -@FUNCTION_REWRITER.register_rewriter( - func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. \ - Conv2dAdaptivePadding.forward', - backend=Backend.TENSORRT.value) -def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x): - """Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for - TensorRT backend. Main changes of this rewritten function is to separate - the computation of padding and encapsulate it into another - `torch.autograd.Function` so that the adaptive padding could be parsed as - `Pad` ops in ONNX with the padding information computed in advance (Only - for static shape configuration). 
- - Args: - x (Tensor): Input tensor of Conv2dAdaptivePadding ops - Returns: - Tensor: forward result of 2D convolution after padding - """ - - deploy_cfg = ctx.cfg - is_dynamic_flag = is_dynamic_shape(deploy_cfg) - if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg): - padded = compute_padding(x.shape[2:], self.weight.shape[2:], - self.stride, self.dilation) - if padded is not None: - padded = [int(_) for _ in padded] - x = AdaptivePadOp.apply(x, padded) - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, - self.dilation, self.groups) - else: - x = ctx.origin_func(x) - return x diff --git a/mmdeploy/mmcv/cnn/transformer.py b/mmdeploy/mmcv/cnn/transformer.py index 58f79657c..6069f5a43 100644 --- a/mmdeploy/mmcv/cnn/transformer.py +++ b/mmdeploy/mmcv/cnn/transformer.py @@ -57,8 +57,7 @@ class MultiHeadAttentionop(torch.autograd.Function): @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.cnn.bricks.transformer.MultiheadAttention.forward', backend=Backend.NCNN.value) -def multiheadattention__forward__ncnn(ctx, - self, +def multiheadattention__forward__ncnn(self, query, key=None, value=None, diff --git a/mmdeploy/mmcv/ops/deform_conv.py b/mmdeploy/mmcv/ops/deform_conv.py index 3e2a436f4..fbbc300b8 100644 --- a/mmdeploy/mmcv/ops/deform_conv.py +++ b/mmdeploy/mmcv/ops/deform_conv.py @@ -4,8 +4,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.deform_conv.DeformConv2dFunction') -def deform_conv__default(ctx, - g, +def deform_conv__default(g, input, offset, weight, @@ -31,8 +30,7 @@ def deform_conv__default(ctx, @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino') -def deform_conv_openvino(ctx, - g, +def deform_conv_openvino(g, input, offset, weight, diff --git a/mmdeploy/mmcv/ops/modulated_deform_conv.py b/mmdeploy/mmcv/ops/modulated_deform_conv.py index df3c338a8..64fd9fdd7 100644 --- a/mmdeploy/mmcv/ops/modulated_deform_conv.py +++ 
b/mmdeploy/mmcv/ops/modulated_deform_conv.py @@ -4,9 +4,8 @@ from mmdeploy.core import SYMBOLIC_REWRITER @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction') -def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias, - stride, padding, dilation, groups, - deform_groups): +def modulated_deform_conv_default(g, input, offset, mask, weight, bias, stride, + padding, dilation, groups, deform_groups): """Rewrite mdcn symbolic function for all backend.""" input_tensors = [input, offset, mask, weight] if bias is not None: diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py index 8c316d0a9..8a29f0b16 100644 --- a/mmdeploy/mmcv/ops/nms.py +++ b/mmdeploy/mmcv/ops/nms.py @@ -82,7 +82,6 @@ class ONNXNMSop(torch.autograd.Function): Returns: NonMaxSuppression op for onnx. """ - if not sym_help._is_value(max_output_boxes_per_class): max_output_boxes_per_class = g.op( 'Constant', @@ -354,8 +353,7 @@ def _multiclass_nms_single(boxes: Tensor, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.mmcv.ops.nms._multiclass_nms') -def multiclass_nms__default(ctx, - boxes: Tensor, +def multiclass_nms__default(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -388,6 +386,7 @@ def multiclass_nms__default(ctx, tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5] and `labels` of shape [N, num_det]. 
""" + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg batch_size = boxes.size(0) if not is_dynamic_batch(deploy_cfg) and batch_size == 1: @@ -414,8 +413,7 @@ def multiclass_nms__default(ctx, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='tensorrt') -def multiclass_nms_static(ctx, - boxes: Tensor, +def multiclass_nms_static(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -479,12 +477,35 @@ def multiclass_nms_static(ctx, 'multiclass_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels', 'index']) -def multiclass_nms(*args, nms_type='nms', **kwargs): +def multiclass_nms(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1, + output_index: bool = False, + nms_type='nms'): """Apis for multiclass nms.""" if nms_type == 'nms': - return _multiclass_nms(*args, **kwargs) + return _multiclass_nms( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + output_index=output_index) elif nms_type == 'nms_rotated': - return multiclass_nms_rotated(*args, **kwargs) + return multiclass_nms_rotated( + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) else: raise NotImplementedError(f'Unsupported nms type: {nms_type}.') @@ -492,8 +513,7 @@ def multiclass_nms(*args, nms_type='nms', **kwargs): @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms', backend=Backend.COREML.value) -def multiclass_nms__coreml(ctx, - boxes: Tensor, +def multiclass_nms__coreml(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, 
iou_threshold: float = 0.5, @@ -556,8 +576,7 @@ def multiclass_nms__coreml(ctx, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms', ir=IR.TORCHSCRIPT) -def multiclass_nms__torchscript(ctx, - boxes: Tensor, +def multiclass_nms__torchscript(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -659,8 +678,7 @@ class AscendBatchNMSOp(torch.autograd.Function): @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms', backend='ascend') -def multiclass_nms__ascend(ctx, - boxes: Tensor, +def multiclass_nms__ascend(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, diff --git a/mmdeploy/mmcv/ops/nms_rotated.py b/mmdeploy/mmcv/ops/nms_rotated.py index 8d701b70c..0449d2030 100644 --- a/mmdeploy/mmcv/ops/nms_rotated.py +++ b/mmdeploy/mmcv/ops/nms_rotated.py @@ -272,8 +272,7 @@ def _multiclass_nms_rotated(boxes: Tensor, @FUNCTION_REWRITER.register_rewriter( func_name='mmdeploy.mmcv.ops.nms_rotated._multiclass_nms_rotated', backend='tensorrt') -def multiclass_nms_rotated__tensorrt(ctx, - boxes: Tensor, +def multiclass_nms_rotated__tensorrt(boxes: Tensor, scores: Tensor, max_output_boxes_per_class: int = 1000, iou_threshold: float = 0.5, @@ -317,7 +316,19 @@ def multiclass_nms_rotated__tensorrt(ctx, 'multiclass_nms_rotated', inputs=['boxes', 'scores'], outputs=['dets', 'labels']) -def multiclass_nms_rotated(*args, **kwargs): +def multiclass_nms_rotated(boxes: Tensor, + scores: Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.1, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1): """Wrapper function for `_multiclass_nms`.""" return mmdeploy.mmcv.ops.nms_rotated._multiclass_nms_rotated( - *args, **kwargs) + boxes, + scores, + max_output_boxes_per_class=max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + 
pre_top_k=pre_top_k, + keep_top_k=keep_top_k) diff --git a/mmdeploy/mmcv/ops/point_sample.py b/mmdeploy/mmcv/ops/point_sample.py index 7f2e43ecf..8051b708d 100644 --- a/mmdeploy/mmcv/ops/point_sample.py +++ b/mmdeploy/mmcv/ops/point_sample.py @@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.ops.point_sample', backend='default') -def point_sample__default(ctx, input, points, align_corners=False, **kwargs): +def point_sample__default(input, points, align_corners=False, **kwargs): """A wrapper around :func:`grid_sample` to support 3D point_coords tensors Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to lie inside ``[0, 1] x [0, 1]`` square. @@ -37,7 +37,7 @@ def point_sample__default(ctx, input, points, align_corners=False, **kwargs): @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.ops.SimpleRoIAlign.forward') -def simple_roialign__forward(ctx, self, features, rois): +def simple_roialign__forward(self, features, rois): """Rewrite `forward` of SimpleRoIAlign. Args: diff --git a/mmdeploy/mmcv/ops/roi_align.py b/mmdeploy/mmcv/ops/roi_align.py index d7eaaa514..6ee901a04 100644 --- a/mmdeploy/mmcv/ops/roi_align.py +++ b/mmdeploy/mmcv/ops/roi_align.py @@ -13,9 +13,9 @@ from mmdeploy.utils import Backend, get_backend, get_ir_config # visible in mmcv. @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.roi_align.__self__', backend='default') -def roi_align_default(ctx, g, input: Tensor, rois: Tensor, - output_size: List[int], spatial_scale: float, - sampling_ratio: int, pool_mode: str, aligned: bool): +def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int], + spatial_scale: float, sampling_ratio: int, + pool_mode: str, aligned: bool): """Rewrite symbolic function for default backend. Replace onnx::RoiAlign with mmcv::MMCVRoiAlign for PPLNN. 
For ONNXRuntime, @@ -41,6 +41,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor, Returns: MMCVRoiAlign op for onnx. """ + ctx = SYMBOLIC_REWRITER.get_context() backend = get_backend(ctx.cfg) if backend == Backend.PPLNN or backend == Backend.TENSORRT: domain = 'mmcv' diff --git a/mmdeploy/mmcv/ops/roi_align_rotated.py b/mmdeploy/mmcv/ops/roi_align_rotated.py index f7707071d..90c2e0414 100644 --- a/mmdeploy/mmcv/ops/roi_align_rotated.py +++ b/mmdeploy/mmcv/ops/roi_align_rotated.py @@ -11,7 +11,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER # is not visible in mmcv. @SYMBOLIC_REWRITER.register_symbolic( 'mmcv.ops.roi_align_rotated.__self__', backend='default') -def roi_align_rotated_default(ctx, g, input: Tensor, rois: Tensor, +def roi_align_rotated_default(g, input: Tensor, rois: Tensor, output_size: List[int], spatial_scale: float, sampling_ratio: int, aligned: bool, clockwise: bool): diff --git a/mmdeploy/mmcv/ops/transformer.py b/mmdeploy/mmcv/ops/transformer.py index 53f7f550b..bf020cce2 100644 --- a/mmdeploy/mmcv/ops/transformer.py +++ b/mmdeploy/mmcv/ops/transformer.py @@ -6,7 +6,7 @@ from mmdeploy.utils import Backend @FUNCTION_REWRITER.register_rewriter( func_name='mmcv.cnn.bricks.transformer.PatchEmbed.forward', backend=Backend.NCNN.value) -def patch_embed__forward__ncnn(ctx, self, x): +def patch_embed__forward__ncnn(self, x): """Rewrite `forward` of PatchEmbed for ncnn backend. 
Args: diff --git a/mmdeploy/pytorch/functions/adaptive_pool.py b/mmdeploy/pytorch/functions/adaptive_pool.py index fb09cd82e..14a185ed8 100644 --- a/mmdeploy/pytorch/functions/adaptive_pool.py +++ b/mmdeploy/pytorch/functions/adaptive_pool.py @@ -9,8 +9,9 @@ from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.adaptive_avg_pool2d') -def adaptive_avg_pool2d__default(ctx, input, output_size): +def adaptive_avg_pool2d__default(input, output_size): """Rewrite `adaptive_avg_pool2d` for default backend.""" + ctx = FUNCTION_REWRITER.get_context() output_size = _pair(output_size) if int(output_size[0]) == int(output_size[1]) == 1: out = ctx.origin_func(input, output_size) @@ -39,6 +40,7 @@ def adaptive_avg_pool2d__default(ctx, input, output_size): @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.adaptive_avg_pool2d', backend=Backend.TORCHSCRIPT.value) -def adaptive_avg_pool2d__ncnn(ctx, input, output_size): +def adaptive_avg_pool2d__ncnn(input, output_size): """Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend.""" + ctx = FUNCTION_REWRITER.get_context() return ctx.origin_func(input, output_size) diff --git a/mmdeploy/pytorch/functions/atan2.py b/mmdeploy/pytorch/functions/atan2.py index a09986a8f..90ce5d63d 100644 --- a/mmdeploy/pytorch/functions/atan2.py +++ b/mmdeploy/pytorch/functions/atan2.py @@ -7,7 +7,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.atan2', backend='default') def atan2__default( - ctx, input1: torch.Tensor, input2: torch.Tensor, ): diff --git a/mmdeploy/pytorch/functions/chunk.py b/mmdeploy/pytorch/functions/chunk.py index 98ad1b2ef..29677b2ea 100644 --- a/mmdeploy/pytorch/functions/chunk.py +++ b/mmdeploy/pytorch/functions/chunk.py @@ -7,7 +7,7 @@ from mmdeploy.utils import IR @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.chunk', backend='ncnn') -def 
chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor: +def chunk__ncnn(self, num_chunks: int, dim: int = 0) -> torch.Tensor: """Rewrite `chunk` for NCNN backend. Chunk in ncnn are not supported, so it should be rewritten. @@ -36,10 +36,7 @@ def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor: @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.chunk', ir=IR.TORCHSCRIPT) -def chunk__torchscript(ctx, - self, - num_chunks: int, - dim: int = 0) -> torch.Tensor: +def chunk__torchscript(self, num_chunks: int, dim: int = 0) -> torch.Tensor: """Rewrite `chunk` for Torchscript. Replace chunk op with split op diff --git a/mmdeploy/pytorch/functions/clip.py b/mmdeploy/pytorch/functions/clip.py index 88a9b6489..c550358f4 100644 --- a/mmdeploy/pytorch/functions/clip.py +++ b/mmdeploy/pytorch/functions/clip.py @@ -13,11 +13,12 @@ from mmdeploy.utils import Backend func_name='torch.Tensor.clamp', backend=Backend.COREML.value) @FUNCTION_REWRITER.register_rewriter( func_name='torch.clamp', backend=Backend.COREML.value) -def clip__coreml(ctx, input, min=None, max=None, **kwargs) -> torch.Tensor: +def clip__coreml(input, min=None, max=None, **kwargs) -> torch.Tensor: """Rewrite `clip` for coreml backend. Cast data type. """ + ctx = FUNCTION_REWRITER.get_context() if min is not None and not isinstance(min, torch.Tensor): min = input.new_tensor(min) diff --git a/mmdeploy/pytorch/functions/expand.py b/mmdeploy/pytorch/functions/expand.py index 0ae90f8a4..c2a1aba70 100644 --- a/mmdeploy/pytorch/functions/expand.py +++ b/mmdeploy/pytorch/functions/expand.py @@ -6,11 +6,12 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.expand', backend='ncnn') -def expand__ncnn(ctx, self, *sizes) -> torch.Tensor: +def expand__ncnn(self, *sizes) -> torch.Tensor: """Rewrite `expand` for NCNN backend. 
Do not expand on batch dim for tensor with ndim >= 3 """ + ctx = FUNCTION_REWRITER.get_context() if self.ndim < 3 or sizes[0] not in [1, -1]: return ctx.origin_func(*sizes) return self diff --git a/mmdeploy/pytorch/functions/flatten.py b/mmdeploy/pytorch/functions/flatten.py index d8d40dd54..7270f32fd 100644 --- a/mmdeploy/pytorch/functions/flatten.py +++ b/mmdeploy/pytorch/functions/flatten.py @@ -13,7 +13,7 @@ from mmdeploy.utils import Backend func_name='torch.Tensor.flatten', backend=Backend.COREML.value) @FUNCTION_REWRITER.register_rewriter( func_name='torch.flatten', backend=Backend.COREML.value) -def flatten__coreml(ctx, input, start_dim=0, end_dim=-1) -> torch.Tensor: +def flatten__coreml(input, start_dim=0, end_dim=-1) -> torch.Tensor: """Rewrite `flatten` for coreml backend. Use reshape instead of flatten diff --git a/mmdeploy/pytorch/functions/getattribute.py b/mmdeploy/pytorch/functions/getattribute.py index 8447aca8b..74e9bfa0b 100644 --- a/mmdeploy/pytorch/functions/getattribute.py +++ b/mmdeploy/pytorch/functions/getattribute.py @@ -6,13 +6,14 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.__getattribute__', backend='ncnn') -def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str): +def tensor__getattribute__ncnn(self: torch.Tensor, name: str): """Rewrite `__getattribute__` of `torch.Tensor` for ncnn backend. Shape node is not supported by ncnn. This function transform dynamic shape to constant shape. 
""" + ctx = FUNCTION_REWRITER.get_context() ret = ctx.origin_func(self, name) if name == 'shape': ret = torch.Size([int(s) for s in ret]) diff --git a/mmdeploy/pytorch/functions/group_norm.py b/mmdeploy/pytorch/functions/group_norm.py index 393fd720d..25fe2b98a 100644 --- a/mmdeploy/pytorch/functions/group_norm.py +++ b/mmdeploy/pytorch/functions/group_norm.py @@ -9,7 +9,6 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.group_norm', backend='ncnn') def group_norm__ncnn( - ctx, input: torch.Tensor, num_groups: int, weight: Union[torch.Tensor, torch.NoneType] = None, diff --git a/mmdeploy/pytorch/functions/interpolate.py b/mmdeploy/pytorch/functions/interpolate.py index a335792f0..39424b8a3 100644 --- a/mmdeploy/pytorch/functions/interpolate.py +++ b/mmdeploy/pytorch/functions/interpolate.py @@ -10,8 +10,7 @@ from mmdeploy.utils import Backend, get_root_logger @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.interpolate', backend='ncnn') -def interpolate__ncnn(ctx, - input: torch.Tensor, +def interpolate__ncnn(input: torch.Tensor, size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, scale_factor: Optional[Union[float, @@ -24,6 +23,7 @@ def interpolate__ncnn(ctx, ncnn require `size` should be constant in ONNX Node. We use `scale_factor` instead of `size` to avoid dynamic size. """ + ctx = FUNCTION_REWRITER.get_context() input_size = input.shape if scale_factor is None: @@ -42,8 +42,7 @@ def interpolate__ncnn(ctx, @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.interpolate', backend='rknn') -def interpolate__rknn(ctx, - input: torch.Tensor, +def interpolate__rknn(input: torch.Tensor, size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, scale_factor: Optional[Union[float, @@ -56,6 +55,7 @@ def interpolate__rknn(ctx, rknn require `size` should be constant in ONNX Node. 
We use `scale_factor` instead of `size` to avoid dynamic size. """ + ctx = FUNCTION_REWRITER.get_context() input_size = input.shape if scale_factor is None: scale_factor = [(s_out / s_in) @@ -77,7 +77,6 @@ def interpolate__rknn(ctx, is_pytorch=True, backend=Backend.TENSORRT.value) def interpolate__tensorrt( - ctx, input: torch.Tensor, size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]] = None, @@ -87,6 +86,7 @@ def interpolate__tensorrt( recompute_scale_factor: Optional[bool] = None, ): """Register default symbolic function for `interpolate`.""" + ctx = FUNCTION_REWRITER.get_context() class BicubicInterpolate(Function): diff --git a/mmdeploy/pytorch/functions/linear.py b/mmdeploy/pytorch/functions/linear.py index 7cfb4735a..616fef732 100644 --- a/mmdeploy/pytorch/functions/linear.py +++ b/mmdeploy/pytorch/functions/linear.py @@ -30,7 +30,6 @@ class GemmOp(torch.autograd.Function): @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.linear', backend='ncnn') def linear__ncnn( - ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[Union[torch.Tensor, torch.NoneType]] = None, @@ -41,6 +40,7 @@ def linear__ncnn( add extra reshape and transpose to support linear operation of different input shape. 
""" + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func dim = input.dim() diff --git a/mmdeploy/pytorch/functions/masked_fill.py b/mmdeploy/pytorch/functions/masked_fill.py index 5e4f67b45..bd8cd7b6c 100644 --- a/mmdeploy/pytorch/functions/masked_fill.py +++ b/mmdeploy/pytorch/functions/masked_fill.py @@ -13,13 +13,14 @@ from mmdeploy.utils.constants import Backend @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value) def masked_fill__onnxruntime( - ctx, input, mask: torch.Tensor, value: Union[torch.Tensor, - Number]) -> torch.Tensor: + input, mask: torch.Tensor, value: Union[torch.Tensor, + Number]) -> torch.Tensor: """Rewrite `masked_fill` for onnxruntime backend. SATRN model as example, when value is set to `float('-inf')`, the results of ORT inferencing turns out to be NAN. """ + ctx = FUNCTION_REWRITER.get_context() if value == float('-inf'): value = -1e34 # hard coding number return ctx.origin_func(input, mask, value) diff --git a/mmdeploy/pytorch/functions/mod.py b/mmdeploy/pytorch/functions/mod.py index e6bb1cb51..bd1bd77d1 100644 --- a/mmdeploy/pytorch/functions/mod.py +++ b/mmdeploy/pytorch/functions/mod.py @@ -11,10 +11,11 @@ from mmdeploy.utils.constants import Backend # TODO add version control when MOD is supported by TensorRT @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.__mod__', backend=Backend.TENSORRT.value) -def mod__tensorrt(ctx, input: torch.Tensor, other: Union[torch.Tensor, - torch.NumberType], - *args, **kwargs) -> torch.Tensor: +def mod__tensorrt(input: torch.Tensor, other: Union[torch.Tensor, + torch.NumberType], *args, + **kwargs) -> torch.Tensor: """Rewrite `mod` when exporting model to ONNX for TensorRT backend.""" + ctx = FUNCTION_REWRITER.get_context() if version.parse(torch.__version__) > version.parse('1.10.0'): return input - (input // other) * other return ctx.origin_func(input, other, *args, **kwargs) diff --git 
a/mmdeploy/pytorch/functions/multi_head_attention_forward.py b/mmdeploy/pytorch/functions/multi_head_attention_forward.py index 8d165649c..fadfc3e91 100644 --- a/mmdeploy/pytorch/functions/multi_head_attention_forward.py +++ b/mmdeploy/pytorch/functions/multi_head_attention_forward.py @@ -46,7 +46,6 @@ class ScaledDotProductAttentionTRT(torch.autograd.Function): func_name='torch.nn.functional._scaled_dot_product_attention', backend=Backend.TENSORRT.value) def _scaled_dot_product_attention__tensorrt( - ctx, q: Tensor, k: Tensor, v: Tensor, diff --git a/mmdeploy/pytorch/functions/normalize.py b/mmdeploy/pytorch/functions/normalize.py index a676439cd..b0ae4ccfe 100644 --- a/mmdeploy/pytorch/functions/normalize.py +++ b/mmdeploy/pytorch/functions/normalize.py @@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.nn.functional.normalize', backend='ncnn') -def normalize__ncnn(ctx, - input: torch.Tensor, +def normalize__ncnn(input: torch.Tensor, p: int = 2, dim: int = 1, eps: float = 1e-12, @@ -18,6 +17,7 @@ def normalize__ncnn(ctx, Make sure L2 norm on channel dim and be exported to ncnn correctly. """ + ctx = FUNCTION_REWRITER.get_context() if dim < 0: dim += input.ndim assert dim != 0, 'Should not normalize on batch index' diff --git a/mmdeploy/pytorch/functions/pad.py b/mmdeploy/pytorch/functions/pad.py index 7f24785e8..82274d5c0 100644 --- a/mmdeploy/pytorch/functions/pad.py +++ b/mmdeploy/pytorch/functions/pad.py @@ -11,7 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.onnx.symbolic_opset11._prepare_onnx_paddings', backend='tensorrt') -def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad): +def _prepare_onnx_paddings__tensorrt(g, input, pad): """Rewrite `_prepare_onnx_paddings` for TensorRT backend. 
For codes like `x = torch.nn.ZeroPad2d((0, a, 0, b))(x)`, where a and b are @@ -26,6 +26,7 @@ def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad): ..., dim_m_begin, dim_m_end, where m is in range [0, n]. """ + ctx = FUNCTION_REWRITER.get_context() torch_version = version_parse(torch.__version__) if torch_version.major == 1 and torch_version.minor < 10: return ctx.origin_func(g, input, pad) diff --git a/mmdeploy/pytorch/functions/repeat.py b/mmdeploy/pytorch/functions/repeat.py index fa528c33f..edb6efc3a 100644 --- a/mmdeploy/pytorch/functions/repeat.py +++ b/mmdeploy/pytorch/functions/repeat.py @@ -8,13 +8,14 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.repeat', backend='tensorrt') -def tensor__repeat__tensorrt(ctx, input: torch.Tensor, - *size: Union[torch.Size, Sequence[int]]): +def tensor__repeat__tensorrt(input: torch.Tensor, *size: Union[torch.Size, + Sequence[int]]): """Rewrite `repeat` for TensorRT backend. Some layers in TensorRT can not be applied on batch axis. add extra axis before operation and remove it afterward. """ + ctx = FUNCTION_REWRITER.get_context() origin_func = ctx.origin_func if input.dim() == 1 and len(size) == 1: diff --git a/mmdeploy/pytorch/functions/size.py b/mmdeploy/pytorch/functions/size.py index 30ead981a..8325f115c 100644 --- a/mmdeploy/pytorch/functions/size.py +++ b/mmdeploy/pytorch/functions/size.py @@ -6,12 +6,13 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.size', backend='ncnn') -def tensor__size__ncnn(ctx, self, *args): +def tensor__size__ncnn(self, *args): """Rewrite `size` for ncnn backend. ONNX Shape node is not supported in ncnn. This function return integer instead of Torch.Size to avoid ONNX Shape node. 
""" + ctx = FUNCTION_REWRITER.get_context() ret = ctx.origin_func(self, *args) if isinstance(ret, torch.Tensor): @@ -26,11 +27,12 @@ def tensor__size__ncnn(ctx, self, *args): @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.size', backend='ascend') -def tensor__size__ascend(ctx, self, *args): +def tensor__size__ascend(self, *args): """Rewrite `size` for ascens backend. Support negative index. """ + ctx = FUNCTION_REWRITER.get_context() if len(args) != 0: index = args[0] diff --git a/mmdeploy/pytorch/functions/tensor_getitem.py b/mmdeploy/pytorch/functions/tensor_getitem.py index 7454a5a6d..17187eeb9 100644 --- a/mmdeploy/pytorch/functions/tensor_getitem.py +++ b/mmdeploy/pytorch/functions/tensor_getitem.py @@ -8,11 +8,12 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.__getitem__', backend='ascend') -def tensor__getitem__ascend(ctx, self, key) -> torch.Tensor: +def tensor__getitem__ascend(self, key) -> torch.Tensor: """Rewrite `getitem` for ascend backend. 
Ascend does not support negative select """ + ctx = FUNCTION_REWRITER.get_context() if not isinstance(key, (tuple, list)): if isinstance(key, int) and key < 0: key = self.dim() + key diff --git a/mmdeploy/pytorch/functions/tensor_setitem.py b/mmdeploy/pytorch/functions/tensor_setitem.py index 6795bc241..4860bbe14 100644 --- a/mmdeploy/pytorch/functions/tensor_setitem.py +++ b/mmdeploy/pytorch/functions/tensor_setitem.py @@ -8,8 +8,9 @@ from mmdeploy.core import FUNCTION_REWRITER, SYMBOLIC_REWRITER @FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.__setitem__') -def tensor__setitem__default(ctx, self, key, value): +def tensor__setitem__default(self, key, value): """Rewrite `setitem` to ease the index put.""" + ctx = FUNCTION_REWRITER.get_context() # only support torch>=1.9.0 if parse(torch.__version__) < parse('1.9.0'): @@ -76,5 +77,5 @@ def tensor__setitem__default(ctx, self, key, value): if parse(torch.__version__) >= parse('1.12.0'): @SYMBOLIC_REWRITER.register_symbolic('copy', is_pytorch=True) - def copy__default(ctx, g, x, y, non_blocking): + def copy__default(g, x, y, non_blocking): return x diff --git a/mmdeploy/pytorch/functions/topk.py b/mmdeploy/pytorch/functions/topk.py index 82569250b..38dac1978 100644 --- a/mmdeploy/pytorch/functions/topk.py +++ b/mmdeploy/pytorch/functions/topk.py @@ -10,8 +10,7 @@ from mmdeploy.utils import get_root_logger @FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='default') @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.topk', backend='default') -def topk__dynamic(ctx, - input: torch.Tensor, +def topk__dynamic(input: torch.Tensor, k: int, dim: Optional[int] = None, largest: bool = True, @@ -20,6 +19,7 @@ def topk__dynamic(ctx, Cast k to tensor and makesure k is smaller than input.shape[dim]. 
""" + ctx = FUNCTION_REWRITER.get_context() if dim is None: dim = int(input.ndim - 1) @@ -37,8 +37,7 @@ def topk__dynamic(ctx, func_name='torch.topk', backend='tensorrt') @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.topk', backend='tensorrt') -def topk__tensorrt(ctx, - input: torch.Tensor, +def topk__tensorrt(input: torch.Tensor, k: int, dim: Optional[int] = None, largest: bool = True, @@ -48,6 +47,7 @@ def topk__tensorrt(ctx, TensorRT does not support topk with dynamic k. This function cast k to constant integer. """ + ctx = FUNCTION_REWRITER.get_context() # https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#topKsetup from mmdeploy.utils.constants import TENSORRT_MAX_TOPK diff --git a/mmdeploy/pytorch/functions/triu.py b/mmdeploy/pytorch/functions/triu.py index 025b2029f..e7e8e501e 100644 --- a/mmdeploy/pytorch/functions/triu.py +++ b/mmdeploy/pytorch/functions/triu.py @@ -5,8 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter(func_name='torch.triu') -def triu__default(ctx, - input: torch.Tensor, +def triu__default(input: torch.Tensor, diagonal: int = 0, *args, **kwargs) -> torch.Tensor: diff --git a/mmdeploy/pytorch/symbolics/adaptive_pool.py b/mmdeploy/pytorch/symbolics/adaptive_pool.py index d27049576..a3461313a 100644 --- a/mmdeploy/pytorch/symbolics/adaptive_pool.py +++ b/mmdeploy/pytorch/symbolics/adaptive_pool.py @@ -5,7 +5,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER @SYMBOLIC_REWRITER.register_symbolic( 'adaptive_avg_pool2d', is_pytorch=True, backend='ncnn') -def adaptive_avg_pool2d__ncnn(ctx, g, x, output_size): +def adaptive_avg_pool2d__ncnn(g, x, output_size): """Register ncnn symbolic function for `adaptive_avg_pool2d`. Align symbolic of adaptive_avg_pool2d in ncnn. 
diff --git a/mmdeploy/pytorch/symbolics/gelu.py b/mmdeploy/pytorch/symbolics/gelu.py index 039e5a114..3d9131181 100644 --- a/mmdeploy/pytorch/symbolics/gelu.py +++ b/mmdeploy/pytorch/symbolics/gelu.py @@ -1,11 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. +from torch.onnx import symbolic_helper from mmdeploy.core import SYMBOLIC_REWRITER from mmdeploy.utils import Backend -@SYMBOLIC_REWRITER.register_symbolic( - 'gelu', is_pytorch=True, arg_descriptors=['v'], backend=Backend.NCNN.value) -def gelu__ncnn(ctx, g, self): - """Support export GELU with ncnn backend.""" +@symbolic_helper.parse_args('v') +def gelu__ncnn_pt111(g, self): + """gelu for torch<=1.12.""" return g.op('mmdeploy::Gelu', self) + + +@SYMBOLIC_REWRITER.register_symbolic( + 'gelu', is_pytorch=True, backend=Backend.NCNN.value) +def gelu__ncnn(g, self, approximate: str = 'none'): + """Support export GELU with ncnn backend.""" + return gelu__ncnn_pt111(g, self) diff --git a/mmdeploy/pytorch/symbolics/grid_sampler.py b/mmdeploy/pytorch/symbolics/grid_sampler.py index b6fdbf5c7..0e3e10510 100644 --- a/mmdeploy/pytorch/symbolics/grid_sampler.py +++ b/mmdeploy/pytorch/symbolics/grid_sampler.py @@ -50,11 +50,12 @@ def grid_sampler_ppl(g, @SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True) -def grid_sampler__default(ctx, *args): +def grid_sampler__default(*args): """Register default symbolic function for `grid_sampler`. Add support to grid_sample to ONNX. 
""" + ctx = SYMBOLIC_REWRITER.get_context() backend = get_backend(ctx.cfg) if backend == Backend.PPLNN: return grid_sampler_ppl(*args) diff --git a/mmdeploy/pytorch/symbolics/hardsigmoid.py b/mmdeploy/pytorch/symbolics/hardsigmoid.py index a4d14173e..27561685e 100644 --- a/mmdeploy/pytorch/symbolics/hardsigmoid.py +++ b/mmdeploy/pytorch/symbolics/hardsigmoid.py @@ -6,7 +6,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER @SYMBOLIC_REWRITER.register_symbolic( 'hardsigmoid', is_pytorch=True, arg_descriptors=['v']) -def hardsigmoid__default(ctx, g, self): +def hardsigmoid__default(g, self): """Support export hardsigmoid This rewrite enable export hardsigmoid in torch<=1.8.2.""" return g.op('HardSigmoid', self, alpha_f=1 / 6) diff --git a/mmdeploy/pytorch/symbolics/instance_norm.py b/mmdeploy/pytorch/symbolics/instance_norm.py index c04e42528..06d287574 100644 --- a/mmdeploy/pytorch/symbolics/instance_norm.py +++ b/mmdeploy/pytorch/symbolics/instance_norm.py @@ -64,7 +64,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled): @SYMBOLIC_REWRITER.register_symbolic( 'group_norm', backend='tensorrt', is_pytorch=True) -def instance_norm__tensorrt(ctx, *args): +def instance_norm__tensorrt(*args): """Register symbolic function for TensorRT backend. 
Notes: diff --git a/mmdeploy/pytorch/symbolics/layer_norm.py b/mmdeploy/pytorch/symbolics/layer_norm.py index 94ea0169a..854ef5fd7 100644 --- a/mmdeploy/pytorch/symbolics/layer_norm.py +++ b/mmdeploy/pytorch/symbolics/layer_norm.py @@ -12,7 +12,7 @@ from mmdeploy.utils import Backend 'layer_norm', is_pytorch=True, arg_descriptors=['v', 'is', 'v', 'v', 'f', 'i']) -def layer_norm__default(ctx, g, input, normalized_shape, weight, bias, eps, +def layer_norm__default(g, input, normalized_shape, weight, bias, eps, cudnn_enable): """Symbolic function for `layer_norm` @@ -62,7 +62,7 @@ def _layer_norm_ncnn(g, input, normalized_shape, weight, bias, eps, @SYMBOLIC_REWRITER.register_symbolic( 'layer_norm', is_pytorch=True, backend=Backend.NCNN.value) -def layer_norm__ncnn(ctx, *args): +def layer_norm__ncnn(*args): """Register default symbolic function for `layer_norm`. Add support to layer_norm to ONNX. diff --git a/mmdeploy/pytorch/symbolics/linear.py b/mmdeploy/pytorch/symbolics/linear.py index 8cb997b40..3236d71bf 100644 --- a/mmdeploy/pytorch/symbolics/linear.py +++ b/mmdeploy/pytorch/symbolics/linear.py @@ -36,7 +36,7 @@ def linear_normal(g, input, weight, bias): @SYMBOLIC_REWRITER.register_symbolic( 'linear', is_pytorch=True, backend=Backend.NCNN.value) -def linear__ncnn(ctx, g, input, weight, bias): +def linear__ncnn(g, input, weight, bias): """Support export linear This rewrite enable export Gemm.""" if bias is None: return linear_no_bias(g, input, weight) diff --git a/mmdeploy/pytorch/symbolics/lstm.py b/mmdeploy/pytorch/symbolics/lstm.py index 3b8926186..2316ef28b 100644 --- a/mmdeploy/pytorch/symbolics/lstm.py +++ b/mmdeploy/pytorch/symbolics/lstm.py @@ -13,8 +13,7 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( func_name='torch.onnx.symbolic_opset9._generic_rnn', backend='ncnn') -def generic_rnn__ncnn(ctx, - g, +def generic_rnn__ncnn(g, variant, input, initial_states, diff --git a/mmdeploy/pytorch/symbolics/roll.py 
b/mmdeploy/pytorch/symbolics/roll.py index 34b892045..7151990d1 100644 --- a/mmdeploy/pytorch/symbolics/roll.py +++ b/mmdeploy/pytorch/symbolics/roll.py @@ -28,6 +28,6 @@ def roll(g, self, shifts, dims): @SYMBOLIC_REWRITER.register_symbolic('roll', is_pytorch=True) -def roll_default(ctx, g, self, shifts, dims): +def roll_default(g, self, shifts, dims): """Support export roll to ONNX with PyTorch version 1.10-.""" return roll(g, self, shifts, dims) diff --git a/mmdeploy/pytorch/symbolics/squeeze.py b/mmdeploy/pytorch/symbolics/squeeze.py index ffcac55be..1484fa3dc 100644 --- a/mmdeploy/pytorch/symbolics/squeeze.py +++ b/mmdeploy/pytorch/symbolics/squeeze.py @@ -5,7 +5,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True) -def squeeze__default(ctx, g, self, dim=None): +def squeeze__default(g, self, dim=None): """Register default symbolic function for `squeeze`. squeeze might be exported with IF node in ONNX, which is not supported in diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index ca7a681c3..d4e33a185 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -16,9 +16,10 @@ def test_function_rewriter(): func_name='torch.mul', backend='tensorrt') @FUNCTION_REWRITER.register_rewriter( func_name='torch.add', backend='tensorrt') - def sub_func(rewriter, x, y): - assert hasattr(rewriter, 'cfg') - assert hasattr(rewriter, 'origin_func') + def sub_func(x, y): + ctx = FUNCTION_REWRITER.get_context('torch.add') + assert hasattr(ctx, 'cfg') + assert hasattr(ctx, 'origin_func') return x - y cfg = dict() @@ -42,7 +43,7 @@ def test_function_rewriter(): # test different config @FUNCTION_REWRITER.register_rewriter( func_name='torch.Tensor.add', backend='default') - def mul_func_class(rewriter, x, y): + def mul_func_class(x, y): return x * y with RewriterContext(cfg, backend='tensorrt'): @@ -62,8 +63,9 @@ def 
test_function_rewriter(): # test origin_func @FUNCTION_REWRITER.register_rewriter( func_name='torch.add', backend='default') - def origin_add_func(rewriter, x, y, **kwargs): - return rewriter.origin_func(x, y, **kwargs) + 1 + def origin_add_func(x, y, **kwargs): + ctx = FUNCTION_REWRITER.get_context('torch.add') + return ctx.origin_func(x, y, **kwargs) + 1 with RewriterContext(cfg): result = torch.add(x, y) @@ -79,7 +81,7 @@ def test_rewrite_empty_function(): function_rewriter = FunctionRewriter() @function_rewriter.register_rewriter(func_name='torch.abcdefghijklmn') - def func(rewriter, x, y): + def func(x, y): return x + y function_rewriter.enter() @@ -101,12 +103,12 @@ class TestHomonymicRewriter: assert c.method() == 1 @function_rewriter.register_rewriter(func_name=path1) - def func_2(ctx, self): + def func_2(self): return 2 @function_rewriter.register_rewriter( func_name=path2, backend=Backend.NCNN.value) - def func_3(ctx, self): + def func_3(self): return 3 function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) @@ -119,11 +121,11 @@ class TestHomonymicRewriter: @function_rewriter2.register_rewriter( func_name=path1, backend=Backend.NCNN.value) - def func_4(ctx, self): + def func_4(self): return 4 @function_rewriter2.register_rewriter(func_name=path2) - def func_5(ctx, self): + def func_5(self): return 5 function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) @@ -147,12 +149,12 @@ def test_rewrite_derived_methods(): function_rewriter = FunctionRewriter() @function_rewriter.register_rewriter(func_name=path1) - def func_2(ctx, self): + def func_2(self): return 2 @function_rewriter.register_rewriter( func_name=path2, backend=Backend.NCNN.value) - def func_3(ctx, self): + def func_3(self): return 3 function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT)) diff --git a/tests/test_core/test_symbolic_register.py b/tests/test_core/test_symbolic_register.py index b012f6a8b..96bebfddf 100644 --- 
a/tests/test_core/test_symbolic_register.py +++ b/tests/test_core/test_symbolic_register.py @@ -40,18 +40,19 @@ def test_symbolic_rewriter(): @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc', backend='ncnn') @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc') - def symbolic_testfunc_default(symbolic_wrapper, g, x, val): - assert hasattr(symbolic_wrapper, 'cfg') + def symbolic_testfunc_default(g, x, val): + ctx = SYMBOLIC_REWRITER.get_context('mmdeploy.TestFunc') + assert hasattr(ctx, 'cfg') return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val) @SYMBOLIC_REWRITER.register_symbolic( 'mmdeploy.TestFunc', backend='tensorrt') - def symbolic_testfunc_tensorrt(symbolic_wrapper, g, x, val): + def symbolic_testfunc_tensorrt(g, x, val): return g.op('mmdeploy::symbolic_testfunc_tensorrt', x, val_i=val) @SYMBOLIC_REWRITER.register_symbolic( 'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) - def symbolic_cummax(symbolic_wrapper, g, input, dim): + def symbolic_cummax(g, input, dim): return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2) class TestModel(torch.nn.Module): @@ -103,12 +104,12 @@ def test_unregister(): test_func = mmdeploy.TestFunc.apply @SYMBOLIC_REWRITER.register_symbolic('mmdeploy.TestFunc') - def symbolic_testfunc_default(symbolic_wrapper, g, x, val): + def symbolic_testfunc_default(g, x, val): return g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val) @SYMBOLIC_REWRITER.register_symbolic( 'cummax', is_pytorch=True, arg_descriptors=['v', 'i']) - def symbolic_cummax(symbolic_wrapper, g, input, dim): + def symbolic_cummax(g, input, dim): return g.op('mmdeploy::cummax_default', input, dim_i=dim, outputs=2) class TestModel(torch.nn.Module): @@ -159,7 +160,7 @@ def test_register_empty_symbolic(): symbolic_rewriter = SymbolicRewriter() @symbolic_rewriter.register_symbolic('mmdeploy.EmptyFunction') - def symbolic_testfunc_default(symbolic_wrapper, g, x, val): + def symbolic_testfunc_default(g, x, val): return 
g.op('mmdeploy::symbolic_testfunc_default', x, val_i=val) symbolic_rewriter.enter() diff --git a/tests/test_mmcv/test_mmcv_cnn.py b/tests/test_mmcv/test_mmcv_cnn.py index 4ff02438b..496197952 100644 --- a/tests/test_mmcv/test_mmcv_cnn.py +++ b/tests/test_mmcv/test_mmcv_cnn.py @@ -30,30 +30,3 @@ def test_multiheadattention_ncnn(): else: assert torch.allclose( model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05) - - -def test_conv2d_adaptive_padding_tensorrt(): - check_backend(Backend.TENSORRT) - from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding - in_channels, out_channels = 3, 64 - kernel_sz = 3 - model = Conv2dAdaptivePadding(in_channels, out_channels, kernel_sz) - dummy_input = torch.rand(1, 3, 256, 256) - - deploy_cfg = Config( - dict( - onnx_config=dict(input_shape=None), - backend_config=dict(type=Backend.TENSORRT.value), - )) - model_outputs = model(dummy_input) - rewrite_inputs = dict(x=dummy_input) - rewrite_outputs, is_backend_output = get_rewrite_outputs( - wrapped_model=model, - model_inputs=rewrite_inputs, - deploy_cfg=deploy_cfg, - run_with_backend=True) - if is_backend_output is None: - assert rewrite_outputs is not None - else: - assert torch.allclose( - model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05) diff --git a/tests/test_mmcv/test_mmcv_ops.py b/tests/test_mmcv/test_mmcv_ops.py index 980442c09..8c3cb185e 100644 --- a/tests/test_mmcv/test_mmcv_ops.py +++ b/tests/test_mmcv/test_mmcv_ops.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp import tempfile import onnx @@ -6,6 +7,7 @@ import pytest import torch from mmengine import Config +from mmdeploy.apis.onnx import export from mmdeploy.core import RewriterContext from mmdeploy.utils import Backend from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend, @@ -38,16 +40,15 @@ def test_ONNXNMSop(iou_threshold, score_threshold, max_output_boxes_per_class): wrapped_model = WrapFunction(wrapped_function).eval() result = wrapped_model(boxes, scores) assert result is not None - onnx_file_path = tempfile.NamedTemporaryFile().name - with RewriterContext({}, opset=11), torch.no_grad(): - torch.onnx.export( - wrapped_model, (boxes, scores), - onnx_file_path, - export_params=True, - keep_initializers_as_inputs=True, - input_names=['boxes', 'scores'], - output_names=['result'], - opset_version=11) + onnx_file_path = tempfile.NamedTemporaryFile(suffix='.onnx').name + onnx_file_prefix = osp.splitext(onnx_file_path)[0] + export( + wrapped_model, (boxes, scores), + onnx_file_prefix, + keep_initializers_as_inputs=False, + input_names=['boxes', 'scores'], + output_names=['result'], + opset_version=11) model = onnx.load(onnx_file_path) assert model.graph.node[3].op_type == 'NonMaxSuppression'