[Refactor] Refactor rewriter context for MMRazor ()

* wip

* update rewriter

* Support all codebases

* update docs

* fix ssd

* rename qualname

* support torch.fx.wrap

* import by torch version

Co-authored-by: pppppM <gjf_mail@126.com>
pull/1536/head
q.yao 2022-12-13 19:03:56 +08:00 committed by GitHub
parent 78901a2451
commit 3f261e6d50
120 changed files with 575 additions and 447 deletions
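In short, this refactor removes the implicit `ctx` first argument from rewriter functions; a rewriter now keeps the original function's signature and fetches its context through `FUNCTION_REWRITER.get_context()`. A minimal before/after sketch of the pattern (mirroring the `torch.Tensor.size` example in the updated `FunctionRewriter` docstring further below):

```python
from mmdeploy.core import FUNCTION_REWRITER

# Before: the rewriter context arrived as the leading `ctx` argument.
#
# @FUNCTION_REWRITER.register_rewriter(
#     func_name='torch.Tensor.size', backend='ncnn')
# def size_of_tensor_static(ctx, self, *args):
#     return ctx.origin_func(self, *args)

# After: the signature matches the original function and the context
# is looked up explicitly inside the body.
@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.Tensor.size', backend='ncnn')
def size_of_tensor_static(self, *args):
    ctx = FUNCTION_REWRITER.get_context()
    return ctx.origin_func(self, *args)
```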

View File

@ -13,7 +13,7 @@ from mmdeploy.utils import get_root_logger
@FUNCTION_REWRITER.register_rewriter(
func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(ctx, self, x):
def forward_of_resnet(self, x):
"""Rewrite the forward implementation of resnet.
Early return the feature map after two down-sampling steps.

View File

@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
def __forward_impl(self, img, img_metas=None, **kwargs):
...
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
def base_detector__forward(self, img, img_metas=None, **kwargs):
...
# call the mark function
return __forward_impl(...)
@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
self,
def yolov3_head__get_bboxes(self,
pred_maps,
img_metas,
cfg=None,

View File

@ -11,7 +11,8 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
def repeat_static(input, *size):
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
@ -72,7 +73,7 @@ The mappings between PyTorch and ONNX are defined in PyTorch with symbolic funct
```python
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
def squeeze_default(g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):

View File

@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
def forward_of_base_classifier(self, img, *args, **kwargs):
"""Rewrite `forward` for default backend."""
return self.simple_test(img, {})
```
@ -63,7 +63,8 @@ In the first example, the output is generated in Python. Sometimes we may make b
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
ctx = FUNCTION_REWRITER.get_context()
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)

View File

@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
def __forward_impl(self, img, img_metas=None, **kwargs):
...
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
def base_detector__forward(self, img, img_metas=None, **kwargs):
...
# call the mark function
return __forward_impl(...)
@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
self,
def yolov3_head__get_bboxes(self,
pred_maps,
img_metas,
cfg=None,

View File

@ -10,7 +10,8 @@ PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
def repeat_static(input, *size):
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
@ -67,7 +68,7 @@ PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义
```python
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
def squeeze_default(g, self, dim=None):
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):

View File

@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
def forward_of_base_classifier(self, img, *args, **kwargs):
"""Rewrite `forward` for default backend."""
return self.simple_test(img, {})
```
@ -63,7 +63,8 @@ def test_baseclassfier_forward():
# Custom rewritten function
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
ctx = FUNCTION_REWRITER.get_context()
if img_metas is None:
img_metas = {}
assert isinstance(img_metas, dict)

View File

@ -116,6 +116,15 @@ def export(model: torch.nn.Module,
input_metas, dict
), f'Expect input_metas type is dict, get {type(input_metas)}.'
model_forward = patched_model.forward
def wrap_forward(forward):
def wrapper(*arg, **kwargs):
return forward(*arg, **kwargs)
return wrapper
patched_model.forward = wrap_forward(patched_model.forward)
patched_model.forward = partial(patched_model.forward,
**input_metas)
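For illustration, here is the wrap-then-bind pattern from the export path above in isolation; the `ToyModel` below is a stand-in, not MMDeploy code. Wrapping the bound method in a plain closure before `functools.partial` binds the keyword arguments leaves the original `forward` untouched, so it can be restored after export:

```python
from functools import partial

import torch


class ToyModel(torch.nn.Module):
    """Stand-in for the patched model; not part of MMDeploy."""

    def forward(self, x, scale=1.0):
        return x * scale


model = ToyModel()
model_forward = model.forward  # keep the original for restoration


def wrap_forward(forward):
    # re-wrap the bound method in a plain closure so that partial
    # application does not touch the method object itself
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)

    return wrapper


model.forward = partial(wrap_forward(model.forward), scale=2.0)
print(model(torch.ones(2)))    # tensor([2., 2.])
model.forward = model_forward  # restore once export is done
```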

View File

@ -5,8 +5,9 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
def model_to_graph__custom_optimizer(*args, **kwargs):
"""Rewriter of _model_to_graph, add custom passes."""
ctx = FUNCTION_REWRITER.get_context()
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
custom_passes = getattr(ctx, 'onnx_custom_passes', None)
@ -23,8 +24,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
@FUNCTION_REWRITER.register_rewriter(
'torch._C._jit_pass_onnx_deduplicate_initializers', backend='tensorrt')
def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
arg2):
def jit_pass_onnx_deduplicate_initializers__disable(graph, param_dict, arg2):
"""This pass will disable TensorRT topk export.
disable for TensorRT.
@ -34,6 +34,6 @@ def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
@FUNCTION_REWRITER.register_rewriter(
'torch._C._jit_pass_onnx_autograd_function_process')
def jit_pass_onnx_autograd_function_process__disable(ctx, graph):
def jit_pass_onnx_autograd_function_process__disable(graph):
"""Disable process autograph function."""
return

View File

@ -8,8 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmaction.models.recognizers.BaseRecognizer.forward')
def base_recognizer__forward(ctx,
self,
def base_recognizer__forward(self,
inputs: Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor',

View File

@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward')
def shufflenetv2_backbone__forward__default(ctx, self, x):
def shufflenetv2_backbone__forward__default(self, x):
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2.
The chunk in the original InvertedResidual.forward will convert to dynamic

View File

@ -9,7 +9,7 @@ from mmdeploy.utils import Backend
func_name= # noqa: E251
'mmcls.models.backbones.vision_transformer.VisionTransformer.forward',
backend=Backend.NCNN.value)
def visiontransformer__forward__ncnn(ctx, self, x):
def visiontransformer__forward__ncnn(self, x):
"""Rewrite `forward` of VisionTransformer for ncnn backend.
The chunk in the original VisionTransformer.forward will convert

View File

@ -12,7 +12,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
def base_classifier__forward(
ctx,
self,
batch_inputs: Tensor,
data_samples: Optional[List[BaseDataElement]] = None,

View File

@ -9,7 +9,7 @@ from mmdeploy.utils import Backend
@FUNCTION_REWRITER.register_rewriter(
'mmcls.models.necks.GlobalAveragePooling.forward',
backend=Backend.DEFAULT.value)
def gap__forward(ctx, self, inputs):
def gap__forward(self, inputs):
"""Rewrite `forward` of GlobalAveragePooling for default backend.
Replace `view` with `flatten` to export simple onnx graph.

View File

@ -11,7 +11,7 @@ from mmdeploy.utils import Backend, get_dynamic_axes
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.utils.attention.MultiheadAttention.forward',
backend=Backend.NCNN.value)
def multiheadattention__forward__ncnn(ctx, self, qkv_input):
def multiheadattention__forward__ncnn(self, qkv_input):
"""Rewrite `forward` of MultiheadAttention used in vision_transformer for
ncnn backend.
@ -53,12 +53,13 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input):
func_name= # noqa: E251
'mmcls.models.utils.ShiftWindowMSA.forward',
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
def shift_window_msa__forward__default(self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.
1. replace dynamic padding with static padding and dynamic slice.
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.
"""
ctx = FUNCTION_REWRITER.get_context()
if get_dynamic_axes(ctx.cfg) is None:
# avoid the weird bug of torch to onnx
return ctx.origin_func(self, query, hw_shape)
@ -142,8 +143,7 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape):
func_name= # noqa: E251
'mmcls.models.utils.ShiftWindowMSA.get_attn_mask',
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
def shift_window_msa__get_attn_mask__default(ctx,
self,
def shift_window_msa__get_attn_mask__default(self,
hw_shape,
window_size,
shift_size,

View File

@ -76,7 +76,7 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes',
backend='tensorrt',
extra_checkers=LibVersionChecker('tensorrt', min_version='8'))
def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
max_shape: Union[Tensor, Sequence[int]]):
"""Clip bboxes for onnx. From TensorRT 8 we can do the operators on the
tensors directly.
@ -165,7 +165,6 @@ def __pad_with_value_if_necessary(x: Tensor,
'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary',
backend=Backend.TENSORRT.value)
def __pad_with_value_if_necessary__tensorrt(
ctx,
x: Tensor,
pad_dim: int,
pad_size: int,
@ -223,12 +222,12 @@ class TRTGatherTopk(torch.autograd.Function):
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.TENSORRT.value)
def __gather_topk__trt(ctx,
*inputs: Sequence[torch.Tensor],
def __gather_topk__trt(*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
"""TensorRT gather_topk."""
ctx = FUNCTION_REWRITER.get_context()
_ = ctx
if is_batched:
index_shape = inds.shape
@ -253,8 +252,7 @@ def __gather_topk__trt(ctx,
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.COREML.value)
def __gather_topk__nonbatch(ctx,
*inputs: Sequence[torch.Tensor],
def __gather_topk__nonbatch(*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:

View File

@ -7,7 +7,7 @@ from mmdeploy.utils import get_common_config
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward')
def focus__forward__default(ctx, self, x):
def focus__forward__default(self, x):
"""Rewrite forward function of Focus class.
Replace slice with transpose.
@ -27,7 +27,7 @@ def focus__forward__default(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward',
backend='ncnn')
def focus__forward__ncnn(ctx, self, x):
def focus__forward__ncnn(self, x):
"""Rewrite forward function of Focus class for ncnn.
Focus width and height information into channel space. ncnn does not
@ -69,7 +69,7 @@ def focus__forward__ncnn(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.WindowMSA.forward',
backend='tensorrt')
def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
def windowmsa__forward__tensorrt(self, x, mask=None):
"""Rewrite forward function of WindowMSA class for TensorRT.
1. replace Gather operation of qkv with split.
@ -80,6 +80,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
ctx = FUNCTION_REWRITER.get_context()
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
-1).permute(2, 0, 3, 1, 4).contiguous()
@ -129,7 +130,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_reverse',
backend='tensorrt')
def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
def shift_window_msa__window_reverse__tensorrt(self, windows, H, W):
"""Rewrite window_reverse function of ShiftWindowMSA class for TensorRT.
For TensorRT, it seems that radical shape transformations are not allowed.
Replace them with soft ones.
@ -155,7 +156,7 @@ def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_partition',
backend='tensorrt')
def shift_window_msa__window_partition__tensorrt(ctx, self, x):
def shift_window_msa__window_partition__tensorrt(self, x):
"""Rewrite window_partition function of ShiftWindowMSA class for TensorRT.
For TensorRT, it seems that radical shape transformations are not allowed.
Replace them with soft ones.
@ -176,7 +177,7 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward')
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
def shift_window_msa__forward__default(self, query, hw_shape):
"""Rewrite forward function of ShiftWindowMSA class.
1. replace dynamic padding with static padding and dynamic slice.

View File

@ -24,7 +24,6 @@ from mmdeploy.utils import Backend, is_dynamic_shape
func_name='mmdet.models.dense_heads.base_dense_head.'
'BaseDenseHead.predict_by_feat')
def base_dense_head__predict_by_feat(
ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
@ -66,6 +65,7 @@ def base_dense_head__predict_by_feat(
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)
@ -211,7 +211,6 @@ def base_dense_head__predict_by_feat(
'BaseDenseHead.predict_by_feat',
backend=Backend.RKNN.value)
def base_dense_head__predict_by_feat__rknn(
ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
@ -253,6 +252,8 @@ def base_dense_head__predict_by_feat__rknn(
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
ctx = FUNCTION_REWRITER.get_context()
# mark nodes for partition
@mark('BaseDenseHead', outputs=['BaseDenseHead.cls', 'BaseDenseHead.loc'])
def __mark_dense_head(cls_scores, bbox_preds):
@ -337,7 +338,6 @@ def base_dense_head__predict_by_feat__rknn(
'BaseDenseHead.predict_by_feat',
backend=Backend.NCNN.value)
def base_dense_head__predict_by_feat__ncnn(
ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
@ -376,6 +376,7 @@ def base_dense_head__predict_by_feat__ncnn(
Returns:
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
"""
ctx = FUNCTION_REWRITER.get_context()
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\

View File

@ -9,7 +9,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.centernet_head.CenterNetHead.predict_by_feat')
def centernet_head__predict_by_feat__default(
ctx,
self,
center_heatmap_preds: List[Tensor],
wh_preds: List[Tensor],

View File

@ -10,7 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.forward_single')
def detrhead__forward_single__default(ctx, self, x, img_metas):
def detrhead__forward_single__default(self, x, img_metas):
"""forward_single of DETRHead.
Ease the mask computation
@ -35,8 +35,7 @@ def detrhead__forward_single__default(ctx, self, x, img_metas):
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.DETRHead.predict_by_feat')
def detrhead__predict_by_feat__default(ctx,
self,
def detrhead__predict_by_feat__default(self,
all_cls_scores_list: List[Tensor],
all_bbox_preds_list: List[Tensor],
batch_img_metas: List[dict],

View File

@ -13,8 +13,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.fovea_head.FoveaHead.predict_by_feat')
def fovea_head__predict_by_feat(ctx,
self,
def fovea_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
@ -49,6 +48,7 @@ def fovea_head__predict_by_feat(ctx,
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
"""
ctx = FUNCTION_REWRITER.get_context()
assert len(cls_scores) == len(bbox_preds)
cfg = self.test_cfg if cfg is None else cfg
num_levels = len(cls_scores)

View File

@ -18,8 +18,7 @@ from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.gfl_head.'
'GFLHead.predict_by_feat')
def gfl_head__predict_by_feat(ctx,
self,
def gfl_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
@ -58,6 +57,7 @@ def gfl_head__predict_by_feat(ctx,
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
backend = get_backend(deploy_cfg)

View File

@ -35,12 +35,13 @@ def _bbox_post_decode(bboxes: torch.Tensor, max_shape: Sequence[int]):
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.reppoints_head.RepPointsHead.points2bbox')
def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
def reppoints_head__points2bbox(self, pts, y_first=True):
"""Rewrite of `points2bbox` in `RepPointsHead`.
Using `self.moment_transfer` in `points2bbox` will cause the error:
RuntimeError: Input, output and indices must be on the current device
"""
ctx = FUNCTION_REWRITER.get_context()
update_moment = hasattr(self, 'moment_transfer')
if update_moment:
moment_transfer = self.moment_transfer
@ -55,7 +56,6 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.reppoints_head.RepPointsHead.predict_by_feat')
def reppoints_head__predict_by_feat(
ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
@ -91,6 +91,7 @@ def reppoints_head__predict_by_feat(
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
num_levels = len(cls_scores)

View File

@ -16,8 +16,7 @@ from mmdeploy.utils import Backend, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.rpn_head.'
'RPNHead.predict_by_feat')
def rpn_head__predict_by_feat(ctx,
self,
def rpn_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
@ -61,6 +60,7 @@ def rpn_head__predict_by_feat(ctx,
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
ctx = FUNCTION_REWRITER.get_context()
img_metas = batch_img_metas
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
@ -166,8 +166,7 @@ def rpn_head__predict_by_feat(ctx,
# TODO: Fix for 1.x
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value)
def rpn_head__get_bboxes__ncnn(ctx,
self,
def rpn_head__get_bboxes__ncnn(self,
cls_scores,
bbox_preds,
img_metas,
@ -204,6 +203,7 @@ def rpn_head__get_bboxes__ncnn(ctx,
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
ctx = FUNCTION_REWRITER.get_context()
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg
assert not is_dynamic_shape(deploy_cfg)

View File

@ -14,8 +14,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.rtmdet_head.'
'RTMDetHead.predict_by_feat')
def rtmdet_head__predict_by_feat(ctx,
self,
def rtmdet_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
batch_img_metas: Optional[List[dict]] = None,
@ -52,6 +51,7 @@ def rtmdet_head__predict_by_feat(ctx,
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""
ctx = FUNCTION_REWRITER.get_context()
assert len(cls_scores) == len(bbox_preds)
device = cls_scores[0].device
cfg = self.test_cfg if cfg is None else cfg

View File

@ -15,8 +15,7 @@ from mmdeploy.utils import Backend, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.yolo_head.'
'YOLOV3Head.predict_by_feat')
def yolov3_head__predict_by_feat(ctx,
self,
def yolov3_head__predict_by_feat(self,
pred_maps: Sequence[Tensor],
cfg: OptConfigType = None,
rescale: bool = False,
@ -47,6 +46,7 @@ def yolov3_head__predict_by_feat(ctx,
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
# mark pred_maps
@ -154,8 +154,7 @@ def yolov3_head__predict_by_feat(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.predict_by_feat',
backend=Backend.NCNN.value)
def yolov3_head__predict_by_feat__ncnn(ctx,
self,
def yolov3_head__predict_by_feat__ncnn(self,
pred_maps,
with_nms=True,
cfg=None,
@ -188,6 +187,7 @@ def yolov3_head__predict_by_feat__ncnn(ctx,
fore-ground class label in Yolov3DetectionOutput starts
from `1`. x1, y1, x2, y2 are normalized in range(0,1).
"""
ctx = FUNCTION_REWRITER.get_context()
num_levels = len(pred_maps)
cfg = self.test_cfg if cfg is None else cfg
post_params = get_post_processing_params(ctx.cfg)

View File

@ -15,8 +15,7 @@ from mmdeploy.utils import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.yolox_head.'
'YOLOXHead.predict_by_feat')
def yolox_head__predict_by_feat(ctx,
self,
def yolox_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]],
@ -57,6 +56,8 @@ def yolox_head__predict_by_feat(ctx,
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""
ctx = FUNCTION_REWRITER.get_context()
# mark pred_maps
@mark('yolo_head', inputs=['cls_scores', 'bbox_preds', 'objectnesses'])
def __mark_pred_maps(cls_scores, bbox_preds, objectnesses):
@ -125,7 +126,6 @@ def yolox_head__predict_by_feat(ctx,
'YOLOXHead.predict_by_feat',
backend=Backend.NCNN.value)
def yolox_head__predict_by_feat__ncnn(
ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
@ -169,6 +169,7 @@ def yolox_head__predict_by_feat__ncnn(
Returns:
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
"""
ctx = FUNCTION_REWRITER.get_context()
from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward
from mmdeploy.utils import get_root_logger
from mmdeploy.utils.config_utils import is_dynamic_shape

View File

@ -12,7 +12,7 @@ from mmdeploy.utils import is_dynamic_shape
@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs):
def __forward_impl(self, batch_inputs, data_samples):
"""Rewrite and adding mark for `forward`.
Encapsulate this function for rewriting `forward` of BaseDetector.
@ -25,10 +25,30 @@ def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs):
return output
@torch.fx.wrap
def _set_metainfo(data_samples, img_shape):
"""Set the metainfo.
Code in this function cannot be traced by fx.
"""
# fx can not trace deepcopy correctly
data_samples = copy.deepcopy(data_samples)
if data_samples is None:
data_samples = [DetDataSample()]
# note that we can not use `set_metainfo`, deepcopy would crash the
# onnx trace.
for data_sample in data_samples:
data_sample.set_field(
name='img_shape', value=img_shape, field_type='metainfo')
return data_samples
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.single_stage.SingleStageDetector.forward')
def single_stage_detector__forward(ctx,
self,
def single_stage_detector__forward(self,
batch_inputs: torch.Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor',
@ -53,9 +73,8 @@ def single_stage_detector__forward(ctx,
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
"""
data_samples = copy.deepcopy(data_samples)
if data_samples is None:
data_samples = [DetDataSample()]
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
# get origin input shape as tensor to support onnx dynamic shape
@ -65,11 +84,6 @@ def single_stage_detector__forward(ctx,
img_shape = [int(val) for val in img_shape]
# set the metainfo
# note that we can not use `set_metainfo`, deepcopy would crash the
# onnx trace.
for data_sample in data_samples:
data_sample.set_field(
name='img_shape', value=img_shape, field_type='metainfo')
data_samples = _set_metainfo(data_samples, img_shape)
return __forward_impl(
ctx, self, batch_inputs, data_samples=data_samples, **kwargs)
return __forward_impl(self, batch_inputs, data_samples=data_samples)

View File

@ -11,8 +11,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.two_stage.TwoStageDetector.extract_feat')
@mark('extract_feat', inputs='img', outputs='feat')
def two_stage_detector__extract_feat(ctx, self, img):
def two_stage_detector__extract_feat(self, img):
"""Rewrite `extract_feat` for default backend.
This function uses the specific `extract_feat` function for the two
@ -27,13 +26,18 @@ def two_stage_detector__extract_feat(ctx, self, img):
list[Tensor]: Each item with shape (N, C, H, W) corresponds one
level of backbone and neck features.
"""
return ctx.origin_func(self, img)
ctx = FUNCTION_REWRITER.get_context()
@mark('extract_feat', inputs='img', outputs='feat')
def __extract_feat_impl(self, img):
return ctx.origin_func(self, img)
return __extract_feat_impl(self, img)
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.two_stage.TwoStageDetector.forward')
def two_stage_detector__forward(ctx,
self,
def two_stage_detector__forward(self,
batch_inputs: torch.Tensor,
data_samples: OptSampleList = None,
mode: str = 'tensor',
@ -58,6 +62,7 @@ def two_stage_detector__forward(ctx,
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
"""
ctx = FUNCTION_REWRITER.get_context()
data_samples = copy.deepcopy(data_samples)
deploy_cfg = ctx.cfg

View File

@ -7,7 +7,7 @@ from mmdeploy.utils import Backend, get_root_logger
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.necks.ssd_neck.L2Norm.forward')
def l2norm__forward__default(ctx, self, x):
def l2norm__forward__default(self, x):
"""Default rewriter for l2norm.
Implement with functional.normalize.
@ -19,11 +19,12 @@ def l2norm__forward__default(ctx, self, x):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.necks.ssd_neck.L2Norm.forward',
backend=Backend.TENSORRT.value)
def l2norm__forward__tensorrt(ctx, self, x):
def l2norm__forward__tensorrt(self, x):
"""rewrite `l2norm` for TensorRT.
TensorRT7 does not support dynamic clamp, which is used in normalize.
"""
ctx = FUNCTION_REWRITER.get_context()
logger = get_root_logger()
trt_version_major = 8
try:
@ -34,6 +35,6 @@ def l2norm__forward__tensorrt(ctx, self, x):
except Exception:
logger.warning('Can not get TensorRT version.')
if trt_version_major >= 8:
return l2norm__forward__default(ctx, self, x)
return l2norm__forward__default(self, x)
else:
return ctx.origin_func(self, x)

View File

@ -17,11 +17,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.ConvFCBBoxHead.forward'
)
@mark(
'bbox_head_forward',
inputs=['bbox_feats'],
outputs=['cls_score', 'bbox_pred'])
def bbox_head__forward(ctx, self, x):
def bbox_head__forward(self, x):
"""Rewrite `forward` for default backend.
This function uses the specific `forward` function for the BBoxHead
@ -37,13 +33,21 @@ def bbox_head__forward(ctx, self, x):
has shape (N, num_det, num_cls) and the bbox_pred has shape
(N, num_det, 4).
"""
return ctx.origin_func(self, x)
ctx = FUNCTION_REWRITER.get_context()
@mark(
'bbox_head_forward',
inputs=['bbox_feats'],
outputs=['cls_score', 'bbox_pred'])
def __forward(self, x):
return ctx.origin_func(self, x)
return __forward(self, x)
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.bbox_heads.bbox_head.BBoxHead.predict_by_feat')
def bbox_head__predict_by_feat(ctx,
self,
def bbox_head__predict_by_feat(self,
rois: Tuple[Tensor],
cls_scores: Tuple[Tensor],
bbox_preds: Tuple[Tensor],
@ -75,6 +79,7 @@ def bbox_head__predict_by_feat(ctx,
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
"""
ctx = FUNCTION_REWRITER.get_context()
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '

View File

@ -11,8 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_bbox')
def cascade_roi_head__predict_bbox(ctx,
self,
def cascade_roi_head__predict_bbox(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
rpn_results_list: List[Tensor],
@ -86,8 +85,7 @@ def cascade_roi_head__predict_bbox(ctx,
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_mask')
def cascade_roi_head__predict_mask(ctx,
self,
def cascade_roi_head__predict_mask(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
results_list: List[Tensor],

View File

@ -14,8 +14,7 @@ from mmdeploy.utils import Backend, get_backend
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.'
'mask_heads.fcn_mask_head.FCNMaskHead.predict_by_feat')
def fcn_mask_head__predict_by_feat(ctx,
self,
def fcn_mask_head__predict_by_feat(self,
mask_preds: Tuple[Tensor],
results_list: List[Tensor],
batch_img_metas: List[dict],
@ -48,6 +47,7 @@ def fcn_mask_head__predict_by_feat(ctx,
(num_instances, ).
- masks (Tensor): Has a shape (num_instances, H, W).
"""
ctx = FUNCTION_REWRITER.get_context()
ori_shape = batch_img_metas[0]['img_shape']
dets, det_labels = results_list
dets = dets.view(-1, 5)

View File

@ -61,13 +61,12 @@ class MultiLevelRoiAlign(Function):
(num_proposals, channel, output_size[1], output_size[0]))
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.roi_extractors.'
'single_level_roi_extractor.SingleRoIExtractor.forward',
backend='tensorrt')
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def single_roi_extractor__forward__tensorrt(ctx,
self,
def single_roi_extractor__forward__tensorrt(self,
feats,
rois,
roi_scale_factor=None):
@ -154,8 +153,7 @@ class AscendRoiExtractor(Function):
'mmdet.models.roi_heads.roi_extractors.'
'single_level_roi_extractor.SingleRoIExtractor.forward',
backend='ascend')
def single_roi_extractor__forward__ascend(ctx,
self,
def single_roi_extractor__forward__ascend(self,
feats,
rois,
roi_scale_factor=None):
@ -185,14 +183,10 @@ def single_roi_extractor__forward__ascend(ctx,
finest_scale, featmap_strides, aligned)
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward')
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def single_roi_extractor__forward(ctx,
self,
feats,
rois,
roi_scale_factor=None):
def single_roi_extractor__forward(self, feats, rois, roi_scale_factor=None):
"""Rewrite `forward` of SingleRoIExtractor for default backend.
Rewrite this function to:
@ -206,6 +200,8 @@ def single_roi_extractor__forward(ctx,
3. use the roi align in torchvision to accelerate the inference.
"""
ctx = FUNCTION_REWRITER.get_context(
'mmdet.models.roi_heads.SingleRoIExtractor.forward')
backend = get_backend(ctx.cfg)
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
@ -291,8 +287,7 @@ class SingleRoIExtractorOpenVINO(Function):
'mmdet.models.roi_heads.roi_extractors.'
'single_level_roi_extractor.SingleRoIExtractor.forward',
backend='openvino')
def single_roi_extractor__forward__openvino(ctx,
self,
def single_roi_extractor__forward__openvino(self,
feats,
rois,
roi_scale_factor=None):
@ -301,6 +296,7 @@ def single_roi_extractor__forward__openvino(ctx,
This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO.
"""
ctx = FUNCTION_REWRITER.get_context()
# Adding original output to SingleRoIExtractorOpenVINO.
state = torch._C._get_tracing_state()
@ -317,12 +313,11 @@ def single_roi_extractor__forward__openvino(ctx,
return result
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
backend=Backend.COREML.value)
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def single_roi_extractor__forward__coreml(ctx,
self,
def single_roi_extractor__forward__coreml(self,
feats,
rois,
roi_scale_factor=None):

View File

@ -10,8 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_bbox')
def standard_roi_head__predict_bbox(ctx,
self,
def standard_roi_head__predict_bbox(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
rpn_results_list: List[Tensor],
@ -72,8 +71,7 @@ def standard_roi_head__predict_bbox(ctx,
@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_mask')
def standard_roi_head__predict_mask(ctx,
self,
def standard_roi_head__predict_mask(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
results_list: List[Tensor],

View File

@ -9,8 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER
func_name='mmdet.models.task_modules.coders.delta_xywh_bbox_coder.'
'DeltaXYWHBBoxCoder.decode',
backend='default')
def deltaxywhbboxcoder__decode(ctx,
self,
def deltaxywhbboxcoder__decode(self,
bboxes,
pred_bboxes,
max_shape=None,
@ -51,8 +50,7 @@ def deltaxywhbboxcoder__decode(ctx,
func_name='mmdet.models.task_modules.coders'
'.delta_xywh_bbox_coder.delta2bbox',
backend='default')
def delta2bbox(ctx,
rois,
def delta2bbox(rois,
deltas,
means=(0., 0., 0., 0.),
stds=(1., 1., 1., 1.),
@ -143,8 +141,7 @@ def delta2bbox(ctx,
func_name='mmdet.models.task_modules.coders.'
'delta_xywh_bbox_coder.delta2bbox',
backend='ncnn')
def delta2bbox__ncnn(ctx,
rois,
def delta2bbox__ncnn(rois,
deltas,
means=(0., 0., 0., 0.),
stds=(1., 1., 1., 1.),

View File

@ -8,11 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
func_name='mmdet.models.task_modules.coders.distance_point_bbox_coder'
'.DistancePointBBoxCoder.decode',
backend='default')
def distancepointbboxcoder__decode(ctx,
self,
points,
pred_bboxes,
max_shape=None):
def distancepointbboxcoder__decode(self, points, pred_bboxes, max_shape=None):
"""Rewrite `mmdet.models.task_modules.coders.distance_point_bbox_coder. \
DistancePointBBoxCoder.decode`

View File

@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.task_modules.coders.tblr_bbox_coder.tblr2bboxes',
backend='default')
def tblr2bboxes(ctx,
priors,
def tblr2bboxes(priors,
tblr,
normalizer=4.0,
normalize_by_wh=True,

View File

@ -77,7 +77,6 @@ grid_priors_trt = GridPriorsTRTOp.apply
'AnchorGenerator.single_level_grid_priors',
backend='tensorrt')
def anchorgenerator__single_level_grid_priors__trt(
ctx,
self,
featmap_size: Tuple[int],
level_idx: int,
@ -98,6 +97,7 @@ def anchorgenerator__single_level_grid_priors__trt(
Returns:
torch.Tensor: Anchors in the overall feature maps.
"""
ctx = FUNCTION_REWRITER.get_context()
from mmdet.models.task_modules.prior_generators import AnchorGenerator
if type(self) != AnchorGenerator:
# only use custom node on default generator.

View File

@ -10,7 +10,6 @@ from mmdeploy.utils.constants import Backend
'.single_level_grid_priors',
backend=Backend.TENSORRT.value)
def mlvl_point_generator__single_level_grid_priors__tensorrt(
ctx,
self,
featmap_size,
level_idx,

View File

@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.utils.transformer.PatchMerging.forward',
backend='tensorrt')
def patch_merging__forward__tensorrt(ctx, self, x, input_size):
def patch_merging__forward__tensorrt(self, x, input_size):
"""Rewrite forward function of PatchMerging class for TensorRT. In original
implementation, mmdet applies nn.unfold to accelerate the inference.
However, the onnx graph of it can not be parsed correctly by TensorRT. In

View File

@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.structures.bbox.transforms.distance2bbox' # noqa
)
def distance2bbox__default(ctx, points, distance, max_shape=None):
def distance2bbox__default(points, distance, max_shape=None):
"""Rewrite `mmdet.core.bbox.transforms.distance2bbox`
Decode distance prediction to bounding box.

View File

@ -9,8 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501
)
def basedetector__forward(ctx,
self,
def basedetector__forward(self,
inputs: list,
data_samples=None,
**kwargs) -> Tuple[List[torch.Tensor]]:

View File

@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_img_feat' # noqa: E501
)
def mvxtwostagedetector__extract_img_feat(ctx, self,
img: torch.Tensor) -> dict:
def mvxtwostagedetector__extract_img_feat(self, img: torch.Tensor) -> dict:
"""Extract features of images."""
if self.with_img_backbone and img is not None:
if img.dim() == 5 and img.size(0) == 1:
@ -26,8 +25,7 @@ def mvxtwostagedetector__extract_img_feat(ctx, self,
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_feat')
def mvxtwostagedetector__extract_feat(ctx, self,
batch_inputs_dict: dict) -> tuple:
def mvxtwostagedetector__extract_feat(self, batch_inputs_dict: dict) -> tuple:
"""Rewrite this func to remove voxelize op.
Args:
@ -47,7 +45,7 @@ def mvxtwostagedetector__extract_feat(ctx, self,
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.forward')
def mvxtwostagedetector__forward(ctx, self, inputs: list, **kwargs):
def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
"""Rewrite this func to remove voxelize op.
Args:

View File

@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.voxel_encoders.pillar_encoder.PillarFeatureNet.forward')
def pillar_encoder__forward(ctx, self, features, num_points, coors, *args,
def pillar_encoder__forward(self, features, num_points, coors, *args,
**kwargs):
"""Rewrite this func to optimize node. Modify the code at
_with_voxel_center and use slice instead of the original operation.

View File

@ -7,11 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.middle_encoders.pillar_scatter.'
'PointPillarsScatter.forward_batch', )
def pointpillarsscatter__forward(ctx,
self,
voxel_features,
coors,
batch_size=1):
def pointpillarsscatter__forward(self, voxel_features, coors, batch_size=1):
"""Scatter features of single sample.
Args:

View File

@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmedit.models.base_models.BaseEditModel.forward', backend='default')
def base_edit_model__forward(
ctx,
self,
batch_inputs: Tensor,
data_samples: Optional[List[BaseDataElement]] = None,

View File

@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textdet.FPNC.forward', backend='tensorrt')
def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
def fpnc__forward__tensorrt(self, inputs, **kwargs):
"""Rewrite `forward` of FPNC for tensorrt backend.
Rewrite this function to replace nearest upsampling with bilinear

View File

@ -10,7 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textdet.heads.BaseTextDetHead.predict')
def base_text_det_head__predict(
ctx, self, x: torch.Tensor,
self, x: torch.Tensor,
batch_data_samples: DetSampleList) -> DetSampleList:
"""Rewrite `predict` of BaseTextDetHead for default backend.
@ -38,7 +38,7 @@ def base_text_det_head__predict(
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textdet.heads.DBHead.predict')
def db_head__predict(ctx, self, x: torch.Tensor,
def db_head__predict(self, x: torch.Tensor,
batch_data_samples: DetSampleList) -> DetSampleList:
"""Rewrite to avoid post-process of text detection head.

View File

@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textdet.SingleStageTextDetector.forward')
def single_stage_text_detector__forward(
ctx,
self,
batch_inputs: torch.Tensor,
data_samples: TextDetDataSample = None,

View File

@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.decoders.BaseDecoder.predict')
def base_decoder__forward(
ctx,
self,
feat: Optional[torch.Tensor] = None,
out_enc: Optional[torch.Tensor] = None,

View File

@ -5,7 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.decoders.CRNNDecoder.forward_train',
backend='ncnn')
def crnndecoder__forward_train__ncnn(ctx, self, feat, *args, **kwargs):
def crnndecoder__forward_train__ncnn(self, feat, *args, **kwargs):
"""Rewrite `forward_train` of CRNNDecoder for ncnn backend.
Rewrite this function to skip permuting dims of outputs from `[W, N, C]` to

View File

@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textrecog.EncoderDecoderRecognizer.forward')
def encoder_decoder_recognizer__forward(ctx, self, batch_inputs: torch.Tensor,
def encoder_decoder_recognizer__forward(self, batch_inputs: torch.Tensor,
data_samples: TextRecogDataSample,
**kwargs) -> TextRecogDataSample:
"""Rewrite `forward` of EncoderDecoderRecognizer for default backend.

View File

@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
func_name='mmocr.models.textrecog.layers.lstm_layer'
'.BidirectionalLSTM.forward',
backend='ncnn')
def bidirectionallstm__forward__ncnn(ctx, self, input):
def bidirectionallstm__forward__ncnn(self, input):
"""Rewrite `forward` of BidirectionalLSTM for ncnn backend.
Rewrite this function to set batch_first of rnn layer to true. RNN in ncnn

View File

@ -15,7 +15,6 @@ from mmdeploy.core import FUNCTION_REWRITER, MODULE_REWRITER
'._2d_attention',
backend='default')
def parallel_sar_decoder__2d_attention(
ctx,
self,
decoder_input: torch.Tensor,
feat: torch.Tensor,
@ -85,8 +84,7 @@ def parallel_sar_decoder__2d_attention(
func_name='mmocr.models.textrecog.decoders.SequentialSARDecoder'
'._2d_attention',
backend='default')
def sequential_sar_decoder__2d_attention(ctx,
self,
def sequential_sar_decoder__2d_attention(self,
y_prev,
feat,
holistic_feat,
@ -151,7 +149,6 @@ def sequential_sar_decoder__2d_attention(ctx,
'.forward_test',
backend='default')
def sequential_sar_decoder__forward_test(
ctx,
self,
feat: torch.Tensor,
out_enc: torch.Tensor,

View File

@ -12,7 +12,6 @@ from mmdeploy.core import FUNCTION_REWRITER
func_name='mmocr.models.textrecog.encoders.SAREncoder.forward',
backend='default')
def sar_encoder__forward(
ctx,
self,
feat: torch.Tensor,
data_samples: Optional[Sequence[TextRecogDataSample]] = None):

View File

@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
'mmpose.models.heads.heatmap_heads.CPMHead.forward')
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.heads.heatmap_heads.MSPNHead.forward')
def mspn_head__forward(ctx, self, feats):
def mspn_head__forward(self, feats):
"""Rewrite `forward` of MSPNHead and CPMHead for default backend.
1. return last stage heatmaps directly.
@ -18,6 +18,7 @@ def mspn_head__forward(ctx, self, feats):
Returns:
output_heatmap (torch.Tensor): Output heatmaps.
"""
ctx = FUNCTION_REWRITER.get_context()
msmu_batch_heatmaps = ctx.origin_func(self, feats)
batch_heatmaps = msmu_batch_heatmaps[-1]
return batch_heatmaps

View File

@ -4,7 +4,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmpose.models.pose_estimators.base.BasePoseEstimator.forward')
def base_pose_estimator__forward(ctx, self, inputs, *args, **kwargs):
def base_pose_estimator__forward(self, inputs, *args, **kwargs):
"""Rewrite `forward` of TopDown for default backend.'.
1.directly call _forward of subclass.

View File

@ -17,8 +17,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.dense_heads.OrientedRPNHead.predict_by_feat')
def rpn_head__predict_by_feat(ctx,
self,
def rpn_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
@ -62,6 +61,7 @@ def rpn_head__predict_by_feat(ctx,
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
batch_mlvl_scores, batch_mlvl_centerness
"""
ctx = FUNCTION_REWRITER.get_context()
img_metas = batch_img_metas
assert len(cls_scores) == len(bbox_preds)
deploy_cfg = ctx.cfg

View File

@ -15,8 +15,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.roi_heads.bbox_heads.GVBBoxHead.predict_by_feat')
def gv_bbox_head__predict_by_feat(ctx,
self,
def gv_bbox_head__predict_by_feat(self,
rois: Tuple[Tensor],
cls_scores: Tuple[Tensor],
bbox_preds: Tuple[Tensor],
@ -60,6 +59,7 @@ def gv_bbox_head__predict_by_feat(ctx,
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '
ctx = FUNCTION_REWRITER.get_context()
img_shape = batch_img_metas[0]['img_shape']
if self.custom_cls_channels:

View File

@ -11,8 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.roi_heads.gv_ratio_roi_head'
'.GVRatioRoIHead.predict_bbox')
def gv_ratio_roi_head__predict_bbox(ctx,
self,
def gv_ratio_roi_head__predict_bbox(self,
x: Tuple[Tensor],
batch_img_metas: List[dict],
rpn_results_list: InstanceList,

View File

@ -66,8 +66,7 @@ class MultiLevelRotatedRoiAlign(Function):
backend='tensorrt')
@mark(
'rotated_roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
def rotated_single_roi_extractor__forward__tensorrt(ctx,
self,
def rotated_single_roi_extractor__forward__tensorrt(self,
feats,
rois,
roi_scale_factor=None):

View File

@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.models.task_modules.coders.gliding_vertex_coder'
'.GVFixCoder.decode')
def gvfixcoder__decode(ctx, self, hboxes, fix_deltas):
def gvfixcoder__decode(self, hboxes, fix_deltas):
"""Rewriter for GVFixCoder decode, support more dimension input."""
assert hboxes.size(

View File

@ -21,7 +21,7 @@ def _dist_torch(point1, point2):
@FUNCTION_REWRITER.register_rewriter(
'mmrotate.structures.bbox.box_converters.qbox2rbox')
def qbox2rbox__default(ctx, boxes: Tensor) -> Tensor:
def qbox2rbox__default(boxes: Tensor) -> Tensor:
"""Convert quadrilateral boxes to rotated boxes.
Implement with PyTorch.

View File

@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.decode_heads.ema_head.EMAModule.forward')
def ema_module__forward(ctx, self, feats):
def ema_module__forward(self, feats):
"""Rewrite `forward` for default backend.
Replace torch.einsum with other operations.

View File

@ -7,8 +7,8 @@ from mmdeploy.utils import get_root_logger
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.decode_heads.point_head.PointHead.get_points_test',
backend='tensorrt')
def point_head__get_points_test__tensorrt(ctx, self, seg_logits,
uncertainty_func, cfg):
def point_head__get_points_test__tensorrt(self, seg_logits, uncertainty_func,
cfg):
"""Sample points for testing.
1. set `num_points` no greater than TENSORRT_MAX_TOPK for tensorrt backend
@ -26,6 +26,7 @@ def point_head__get_points_test__tensorrt(ctx, self, seg_logits,
2) that contains [0, 1] x [0, 1] normalized coordinates of the
most uncertain points from the ``height x width`` grid.
"""
ctx = FUNCTION_REWRITER.get_context()
from mmdeploy.utils.constants import TENSORRT_MAX_TOPK
if cfg.subdivision_num_points > TENSORRT_MAX_TOPK:

View File

@ -7,8 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def base_segmentor__forward(ctx,
self,
def base_segmentor__forward(self,
inputs,
data_samples=None,
mode='predict',
@ -27,6 +26,7 @@ def base_segmentor__forward(ctx,
Returns:
torch.Tensor: Output segmentation map of shape [N, 1, H, W].
"""
ctx = FUNCTION_REWRITER.get_context()
if data_samples is None:
data_samples = [SegDataSample()]

View File

@ -4,8 +4,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.CascadeEncoderDecoder.predict')
def cascade_encoder_decoder__predict(ctx, self, inputs, data_samples,
**kwargs):
def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
"""Rewrite `predict` for default backend.
1. only support mode=`whole` inference

View File

@ -5,7 +5,7 @@ from mmdeploy.utils.constants import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.EncoderDecoder.predict')
def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs):
def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
"""Rewrite `predict` for default backend.
1. only support mode=`whole` inference
@ -32,7 +32,7 @@ def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.segmentors.EncoderDecoder.predict',
backend=Backend.RKNN.value)
def encoder_decoder__predict__rknn(ctx, self, inputs, data_samples, **kwargs):
def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
"""Rewrite `predict` for RKNN backend.
Early return to avoid argmax operator.

View File

@ -8,7 +8,7 @@ from mmdeploy.utils import is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='mmseg.models.utils.UpConvBlock.forward')
def up_conv_block__forward(ctx, self, skip, x):
def up_conv_block__forward(self, skip, x):
"""Rewrite `forward` for default backend.
To support dynamic shape for UNet backbone,
@ -23,6 +23,7 @@ def up_conv_block__forward(ctx, self, skip, x):
Returns:
Tensor: Upsampled output feature map.
"""
ctx = FUNCTION_REWRITER.get_context()
from mmcv.cnn import ConvModule
# only valid when self.upsample is from build_upsample_layer

View File

@ -62,18 +62,20 @@ class Mark(torch.autograd.Function):
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.core.optimizers.function_marker.Mark.symbolic')
def mark_symbolic(rewriter, g, x, *args):
def mark_symbolic(g, x, *args):
"""Rewrite symbolic of mark op."""
if cfg_apply_marks(rewriter.cfg):
return rewriter.origin_func(g, x, *args)
ctx = FUNCTION_REWRITER.get_context()
if cfg_apply_marks(ctx.cfg):
return ctx.origin_func(g, x, *args)
return x
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.core.optimizers.function_marker.Mark.forward')
def forward_of_mark(rewriter, ctx, x, dtype, shape, func, func_id, type, name,
id, attrs) -> torch.Tensor:
def forward_of_mark(ctx, x, dtype, shape, func, func_id, type, name, id,
attrs) -> torch.Tensor:
"""Rewrite forward of mark op."""
rewriter = FUNCTION_REWRITER.get_context()
deploy_cfg = rewriter.cfg
# save calib data
apply_marks = cfg_apply_marks(deploy_cfg)
@ -182,7 +184,7 @@ def mark_tensors(xs: Any, func: str, func_id: int, io_type: str, ctx: Any,
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.core.optimizers.function_marker.mark_tensors', ir=IR.TORCHSCRIPT)
def remove_mark__torchscript(ctx, xs: Any, *args, **kwargs):
def remove_mark__torchscript(xs: Any, *args, **kwargs):
"""Disable all marks for TorchScript backend.
As the Node `mark` cannot be traced, we just return the original input
@ -216,12 +218,15 @@ def mark(func_name: Optional[str] = None,
>>> from mmdeploy.core import FUNCTION_REWRITER, mark
>>> @FUNCTION_REWRITER.register_rewriter(
>>> func_name='mmdet.models.roi_heads.ConvFCBBoxHead.forward')
>>> @mark(
>>> 'bbox_head_forward',
>>> inputs=['bbox_feats'],
>>> outputs=['cls_score', 'bbox_pred'])
>>> def forward_of_bbox_head(ctx, self, x):
>>> return ctx.origin_func(self, x)
>>> def forward_of_bbox_head(self, x):
>>> ctx = FUNCTION_REWRITER.get_context()
>>> @mark(
>>> 'bbox_head_forward',
>>> inputs=['bbox_feats'],
>>> outputs=['cls_score', 'bbox_pred'])
>>> def _impl():
>>> return ctx.origin_func(self, x)
>>> return _impl()
"""
MARK_FUNCTION_COUNT[func_name] = 0

View File

@ -1,11 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
import types
from collections import defaultdict
from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
Tuple, Union)
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
copy_function, get_frame_func, get_func_qualname,
import_function)
try:
try:
# torch>=1.10.0
from torch.fx._symbolic_trace import _wrapped_fns_to_patch
except ImportError:
# 1.10.0>torch>=1.8.0
from torch.fx.symbolic_trace import _wrapped_fns_to_patch
except ImportError:
# torch<1.8.0
_wrapped_fns_to_patch = []
def _replace_all_obj(obj: Any,
new_obj: Any,
@ -92,6 +106,24 @@ def _del_func(path: str):
continue
def _fx_wrap_copied_fn(func: types.FunctionType,
copied_func: types.FunctionType):
"""If a function is wrapped by torch.fx.wrap, its copy also needs to be
wrapped by torch.fx.wrap."""
if not hasattr(func, '__globals__'):
return
wrapped_fns_globals = [item[0] for item in _wrapped_fns_to_patch]
wrapped_fns_names = [item[1] for item in _wrapped_fns_to_patch]
# check if wrapped by torch.fx.wrap
if func.__globals__ in wrapped_fns_globals:
idx = wrapped_fns_globals.index(func.__globals__)
fn_name = wrapped_fns_names[idx]
# a hacky way to wrap the func in copied func
_wrapped_fns_to_patch.append((copied_func.__globals__, fn_name))
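For background, `torch.fx.wrap` registers a function name against its module globals so that symbolic tracing records an opaque `call_function` node rather than tracing into the body. Because `copy_function` gives the copied rewrite function fresh globals, the registration has to be repeated for the copy, which is what the append above does. A standalone sketch of the wrap behaviour, independent of MMDeploy:

```python
import torch
import torch.fx


@torch.fx.wrap
def untraceable_helper(x):
    # fx will not trace into this body; it shows up in the traced
    # graph as a single call_function node instead
    return x + 1


class M(torch.nn.Module):
    def forward(self, x):
        return untraceable_helper(x)


graph = torch.fx.symbolic_trace(M()).graph
print(graph)  # contains `call_function[target=untraceable_helper]`
```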
class FunctionRewriter:
"""A function rewriter which maintains rewritten functions.
@ -102,7 +134,8 @@ class FunctionRewriter:
Examples:
>>> @FUNCTION_REWRITER.register_rewriter(
>>> func_name='torch.Tensor.size', backend='ncnn')
>>> def size_of_tensor_static(ctx, self, *args):
>>> def size_of_tensor_static(self, *args):
>>> ctx = FUNCTION_REWRITER.get_context()
>>> ret = ctx.origin_func(self, *args)
>>> if isinstance(ret, torch.Tensor):
>>> ret = int(ret)
@ -114,6 +147,7 @@ class FunctionRewriter:
def __init__(self):
self._registry = RewriterRegistry()
self._func_contexts = defaultdict(list)
def register_rewriter(
self,
@ -140,8 +174,11 @@ class FunctionRewriter:
def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
"""The implementation of function rewrite."""
self._func_contexts.clear()
# Get current records
functions_records = self._registry.get_records(env)
# Record the current number of functions wrapped by torch.fx.wrap
self._ori_fx_wrap_num = len(_wrapped_fns_to_patch)
self._origin_functions = list()
self._additional_functions = list()
@ -181,15 +218,25 @@ class FunctionRewriter:
# Create context_caller
rewrite_function = record_dict['_object']
# The function has different globals before and after copying
rewrite_function = copy_function(rewrite_function)
extra_kwargs = kwargs.copy()
extra_kwargs.update(record_dict)
context_caller = ContextCaller(
rewrite_function, origin_func, cfg,
**extra_kwargs).get_wrapped_caller()
context_caller = ContextCaller(rewrite_function, origin_func,
cfg, **extra_kwargs)
# If a function in rewrite_function's globals is wrapped by
# torch.fx.wrap, the function with the same name in the copied
# function's globals needs to be wrapped as well.
_fx_wrap_copied_fn(record_dict['_object'], context_caller.func)
qualname = get_func_qualname(rewrite_function)
self._func_contexts[qualname].append(context_caller)
self._func_contexts[function_path].append(context_caller)
# Cache the new function to avoid the homonym bug
new_functions.append(
dict(func_path=function_path, origin_func=context_caller))
dict(
func_path=function_path, origin_func=rewrite_function))
for func_dict in new_functions:
function_path = func_dict['func_path']
@ -199,9 +246,46 @@ class FunctionRewriter:
def exit(self):
"""Recover the function rewrite."""
# Restore _wrapped_fns_to_patch
cur_fx_wrap_num = len(_wrapped_fns_to_patch)
for _ in range(cur_fx_wrap_num - self._ori_fx_wrap_num):
_wrapped_fns_to_patch.pop(-1)
for func_dict in self._origin_functions:
func_path = func_dict['func_path']
func = func_dict['origin_func']
_set_func(func_path, func)
for func_path in self._additional_functions:
_del_func(func_path)
self._func_contexts.clear()
def get_context(self, key: Optional[str] = None) -> ContextCaller:
"""Get the context of rewriter.
Args:
key: key to the context.
Returns:
ContextCaller: context of function
"""
func = None
if key is None:
func = get_frame_func(2)
key = get_func_qualname(func)
# get all contexts
ctxs = self._func_contexts.get(key, [])
if func is None:
assert len(ctxs) == 1
return ctxs[0]
ctx = None
for tmp_ctx in ctxs:
if tmp_ctx.func == func:
ctx = tmp_ctx
if ctx is None:
get_root_logger().warning(f'Cannot find the context of {key}')
return ctx
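A minimal sketch of the new usage, with an illustrative rewrite target:

```python
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(func_name='torch.Tensor.size')
def size__default(self, *args):
    # the context is no longer passed as the first argument; it is looked up
    # from the calling frame, or by an explicit qualified-name key
    ctx = FUNCTION_REWRITER.get_context()
    return ctx.origin_func(self, *args)
```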

View File

@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import inspect
import types
import warnings
from abc import ABCMeta, abstractmethod
from functools import wraps
@ -342,8 +344,9 @@ class ContextCaller:
Example:
>>> @FUNCTION_REWRITER.register_rewriter(func_name='torch.add')
>>> def func(ctx, x, y):
>>> def func(x, y):
>>> # ctx is an instance of ContextCaller
>>> ctx = FUNCTION_REWRITER.get_context()
>>> print(ctx.cfg)
>>> return x + y
"""
@ -379,3 +382,61 @@ class ContextCaller:
return self.func(self, *args, **kwargs)
return wrapper
def get_func_qualname(func: Callable) -> str:
"""get function name."""
assert isinstance(func, Callable), f'{func} is not a Callable object.'
_func_name = None
if hasattr(func, '__qualname__'):
_func_name = f'{func.__module__}.{func.__qualname__}'
elif hasattr(func, '__class__'):
_func_name = str(func.__class__)
else:
_func_name = str(func)
return _func_name
def get_frame_func(top: int = 1) -> Callable:
"""get func of frame."""
frameinfo = inspect.stack()[top]
frame = frameinfo.frame
g_vars = frame.f_globals
func_name = frameinfo.function
assert func_name in g_vars, \
f'Cannot find function: {func_name} in globals.'
func = g_vars[func_name]
return func
def get_frame_qualname(top: int = 1) -> str:
"""get frame name."""
frameinfo = inspect.stack()[top]
frame = frameinfo.frame
g_vars = frame.f_globals
func_name = frameinfo.function
assert func_name in g_vars, \
f'Cannot find function: {func_name} in globals.'
func = g_vars[func_name]
module_name = inspect.getmodule(func).__name__
return f'{module_name}.{func_name}'
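A quick illustration of the frame inspection, assuming the snippet runs at module level in `__main__`:

```python
def caller():
    # inspect.stack()[1] inside get_frame_qualname is this frame
    return get_frame_qualname(top=1)


print(caller())  # -> '__main__.caller'
```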
def copy_function(f: types.FunctionType):
"""Copy the function."""
# copy the globals so each copied function can bind a different origin
glb = f.__globals__.copy()
name = f.__name__
g = types.FunctionType(
f.__code__,
glb,
name=name,
argdefs=f.__defaults__,
closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
glb[name] = g
return g
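The copied globals are what make per-copy contexts possible: each copy can bind a different origin function without affecting the others. A small sketch with illustrative names:

```python
def greet():
    return origin()  # resolved through greet's own __globals__ at call time


g1 = copy_function(greet)
g2 = copy_function(greet)
g1.__globals__['origin'] = lambda: 'hello'
g2.__globals__['origin'] = lambda: 'world'
assert g1() == 'hello' and g2() == 'world'
```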

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Union
import torch
@ -7,7 +8,8 @@ from torch.onnx.symbolic_helper import parse_args
from mmdeploy.utils import IR, Backend, get_root_logger
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
eval_with_import)
copy_function, eval_with_import, get_frame_func,
get_func_qualname)
class SymbolicRewriter:
@ -21,7 +23,7 @@ class SymbolicRewriter:
Examples:
>>> @SYMBOLIC_REWRITER.register_symbolic('squeeze', \
>>> is_pytorch=True)
>>> def squeeze_default(ctx, g, self, dim=None):
>>> def squeeze_default(g, self, dim=None):
>>> if dim is None:
>>> dims = []
>>> for i, size in enumerate(self.type().sizes()):
@ -34,6 +36,7 @@ class SymbolicRewriter:
def __init__(self) -> None:
self._registry = RewriterRegistry()
self._func_contexts = defaultdict(list)
def register_symbolic(self,
func_name: str,
@ -75,6 +78,9 @@ class SymbolicRewriter:
opset: int = 11,
**kwargs):
"""The implementation of symbolic register."""
# clear context
self._func_contexts.clear()
# Get current records
symbolic_records = self._registry.get_records(env)
@ -84,19 +90,27 @@ class SymbolicRewriter:
for function_name, record_dict in symbolic_records:
symbolic_function = record_dict['_object']
symbolic_function = copy_function(symbolic_function)
arg_descriptors = record_dict['arg_descriptors']
extra_kwargs = kwargs.copy()
extra_kwargs.update(record_dict)
context_caller = ContextCaller(symbolic_function, None, cfg,
**extra_kwargs)
# register context
qualname = get_func_qualname(symbolic_function)
self._func_contexts[qualname].append(context_caller)
self._func_contexts[function_name].append(context_caller)
if arg_descriptors is not None and len(arg_descriptors) > 0:
context_caller = parse_args(*arg_descriptors)(context_caller)
symbolic_function = parse_args(*arg_descriptors)(
symbolic_function)
is_pytorch = record_dict['is_pytorch']
if is_pytorch:
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic(f'::{function_name}',
context_caller, opset)
symbolic_function, opset)
# Save domain and version
self._pytorch_symbolic.append((function_name, '', opset))
@ -123,7 +137,7 @@ class SymbolicRewriter:
self._extra_symbolic.append((origin_func, origin_symbolic))
# Cache the new function to avoid the homonym bug
new_functions.append((origin_func, context_caller))
new_functions.append((origin_func, symbolic_function))
for origin_func, new_func in new_functions:
origin_symbolic = getattr(origin_func, 'symbolic', None)
@ -132,6 +146,9 @@ class SymbolicRewriter:
def exit(self):
"""The implementation of symbolic unregister."""
# clear context
self._func_contexts.clear()
# Unregister pytorch op
if hasattr(torch.onnx, 'unregister_custom_op_symbolic'):
from torch.onnx import unregister_custom_op_symbolic
@ -149,3 +166,33 @@ class SymbolicRewriter:
# Unregister custom op
for origin_func, origin_symbolic in self._extra_symbolic:
origin_func.symbolic = origin_symbolic
def get_context(self, key: Optional[str] = None) -> ContextCaller:
"""Get the context of rewriter.
Args:
key: key to the context.
Returns:
ContextCaller: context of function
"""
func = None
if key is None:
func = get_frame_func(2)
key = get_func_qualname(func)
# get all contexts
ctxs = self._func_contexts.get(key, [])
if func is None:
assert len(ctxs) == 1
return ctxs[0]
ctx = None
for tmp_ctx in ctxs:
if tmp_ctx.func == func:
ctx = tmp_ctx
if ctx is None:
get_root_logger().warning(f'Cannot find the context of {key}')
return ctx

View File

@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import conv2d_adaptive_padding # noqa: F401,F403
from .transformer import MultiHeadAttentionop
__all__ = ['MultiHeadAttentionop']

View File

@ -1,86 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
def compute_padding(input_size, kernel_size, stride, dilation):
"""Compute padding."""
input_h, input_w = input_size
kernel_h, kernel_w = kernel_size
stride_h, stride_w = stride
dilation_h, dilation_w = dilation
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max(
(output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h,
0)
pad_w = max(
(output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w,
0)
if pad_w > 0 or pad_h > 0:
padded = [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
]
else:
padded = None
return padded
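A quick sanity check of the formula, assuming a 3x3, stride-2 convolution on a 224x224 input:

```python
# output size is ceil(224 / 2) = 112, so each spatial dim needs
# (112 - 1) * 2 + (3 - 1) * 1 + 1 - 224 = 1 extra pixel of padding
padded = compute_padding((224, 224), (3, 3), (2, 2), (1, 1))
assert padded == [0, 1, 0, 1]  # [w_left, w_right, h_top, h_bottom]
```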
class AdaptivePadOp(torch.autograd.Function):
"""Dummy adaptive pad op."""
@staticmethod
def forward(ctx, x, padded):
if padded is not None:
x = F.pad(x, padded)
return x
@staticmethod
def symbolic(g, x, padded):
if padded is None:
return g.op('Identity', x)
padded = g.op(
'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
constant_value = g.op(
'Constant', value_t=torch.tensor(0, dtype=torch.int64))
return g.op(
'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. \
Conv2dAdaptivePadding.forward',
backend=Backend.TENSORRT.value)
def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
"""Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for
TensorRT backend. Main changes of this rewritten function is to separate
the computation of padding and encapsulate it into another
`torch.autograd.Function` so that the adaptive padding could be parsed as
`Pad` ops in ONNX with the padding information computed in advance (Only
for static shape configuration).
Args:
x (Tensor): Input tensor of Conv2dAdaptivePadding ops
Returns:
Tensor: forward result of 2D convolution after padding
"""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
padded = compute_padding(x.shape[2:], self.weight.shape[2:],
self.stride, self.dilation)
if padded is not None:
padded = [int(_) for _ in padded]
x = AdaptivePadOp.apply(x, padded)
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
else:
x = ctx.origin_func(x)
return x

View File

@ -57,8 +57,7 @@ class MultiHeadAttentionop(torch.autograd.Function):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcv.cnn.bricks.transformer.MultiheadAttention.forward',
backend=Backend.NCNN.value)
def multiheadattention__forward__ncnn(ctx,
self,
def multiheadattention__forward__ncnn(self,
query,
key=None,
value=None,

View File

@ -4,8 +4,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.deform_conv.DeformConv2dFunction')
def deform_conv__default(ctx,
g,
def deform_conv__default(g,
input,
offset,
weight,
@ -31,8 +30,7 @@ def deform_conv__default(ctx,
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino')
def deform_conv_openvino(ctx,
g,
def deform_conv_openvino(g,
input,
offset,
weight,

View File

@ -4,9 +4,8 @@ from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction')
def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias,
stride, padding, dilation, groups,
deform_groups):
def modulated_deform_conv_default(g, input, offset, mask, weight, bias, stride,
padding, dilation, groups, deform_groups):
"""Rewrite mdcn symbolic function for all backend."""
input_tensors = [input, offset, mask, weight]
if bias is not None:

View File

@ -82,7 +82,6 @@ class ONNXNMSop(torch.autograd.Function):
Returns:
NonMaxSuppression op for onnx.
"""
if not sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = g.op(
'Constant',
@ -354,8 +353,7 @@ def _multiclass_nms_single(boxes: Tensor,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms')
def multiclass_nms__default(ctx,
boxes: Tensor,
def multiclass_nms__default(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
@ -388,6 +386,7 @@ def multiclass_nms__default(ctx,
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
and `labels` of shape [N, num_det].
"""
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg
batch_size = boxes.size(0)
if not is_dynamic_batch(deploy_cfg) and batch_size == 1:
@ -414,8 +413,7 @@ def multiclass_nms__default(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='tensorrt')
def multiclass_nms_static(ctx,
boxes: Tensor,
def multiclass_nms_static(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
@ -479,12 +477,35 @@ def multiclass_nms_static(ctx,
'multiclass_nms',
inputs=['boxes', 'scores'],
outputs=['dets', 'labels', 'index'])
def multiclass_nms(*args, nms_type='nms', **kwargs):
def multiclass_nms(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1,
output_index: bool = False,
nms_type='nms'):
"""Apis for multiclass nms."""
if nms_type == 'nms':
return _multiclass_nms(*args, **kwargs)
return _multiclass_nms(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=output_index)
elif nms_type == 'nms_rotated':
return multiclass_nms_rotated(*args, **kwargs)
return multiclass_nms_rotated(
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)
else:
raise NotImplementedError(f'Unsupported nms type: {nms_type}.')
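The explicit keyword signature, replacing the previous `*args, **kwargs`, keeps the argument names visible when the `@mark`-wrapped call is traced. A hypothetical call from a detection head rewrite, with made-up shapes (outside the deploy pipeline the outputs are placeholders):

```python
import torch

boxes = torch.rand(1, 1000, 4)    # [N, num_boxes, 4]
scores = torch.rand(1, 1000, 80)  # [N, num_boxes, num_classes]
dets, labels = multiclass_nms(
    boxes,
    scores,
    max_output_boxes_per_class=200,
    iou_threshold=0.5,
    score_threshold=0.05,
    pre_top_k=1000,
    keep_top_k=100)
```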
@ -492,8 +513,7 @@ def multiclass_nms(*args, nms_type='nms', **kwargs):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
backend=Backend.COREML.value)
def multiclass_nms__coreml(ctx,
boxes: Tensor,
def multiclass_nms__coreml(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
@ -556,8 +576,7 @@ def multiclass_nms__coreml(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
ir=IR.TORCHSCRIPT)
def multiclass_nms__torchscript(ctx,
boxes: Tensor,
def multiclass_nms__torchscript(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
@ -659,8 +678,7 @@ class AscendBatchNMSOp(torch.autograd.Function):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
backend='ascend')
def multiclass_nms__ascend(ctx,
boxes: Tensor,
def multiclass_nms__ascend(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,

View File

@ -272,8 +272,7 @@ def _multiclass_nms_rotated(boxes: Tensor,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdeploy.mmcv.ops.nms_rotated._multiclass_nms_rotated',
backend='tensorrt')
def multiclass_nms_rotated__tensorrt(ctx,
boxes: Tensor,
def multiclass_nms_rotated__tensorrt(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
@ -317,7 +316,19 @@ def multiclass_nms_rotated__tensorrt(ctx,
'multiclass_nms_rotated',
inputs=['boxes', 'scores'],
outputs=['dets', 'labels'])
def multiclass_nms_rotated(*args, **kwargs):
def multiclass_nms_rotated(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.1,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1):
"""Wrapper function for `_multiclass_nms`."""
return mmdeploy.mmcv.ops.nms_rotated._multiclass_nms_rotated(
*args, **kwargs)
boxes,
scores,
max_output_boxes_per_class=max_output_boxes_per_class,
iou_threshold=iou_threshold,
score_threshold=score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k)

View File

@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcv.ops.point_sample', backend='default')
def point_sample__default(ctx, input, points, align_corners=False, **kwargs):
def point_sample__default(input, points, align_corners=False, **kwargs):
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
lie inside ``[0, 1] x [0, 1]`` square.
@ -37,7 +37,7 @@ def point_sample__default(ctx, input, points, align_corners=False, **kwargs):
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcv.ops.SimpleRoIAlign.forward')
def simple_roialign__forward(ctx, self, features, rois):
def simple_roialign__forward(self, features, rois):
"""Rewrite `forward` of SimpleRoIAlign.
Args:

View File

@ -13,9 +13,9 @@ from mmdeploy.utils import Backend, get_backend, get_ir_config
# visible in mmcv.
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.roi_align.__self__', backend='default')
def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
output_size: List[int], spatial_scale: float,
sampling_ratio: int, pool_mode: str, aligned: bool):
def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
spatial_scale: float, sampling_ratio: int,
pool_mode: str, aligned: bool):
"""Rewrite symbolic function for default backend.
Replace onnx::RoiAlign with mmcv::MMCVRoiAlign for PPLNN. For ONNXRuntime,
@ -41,6 +41,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
Returns:
MMCVRoiAlign op for onnx.
"""
ctx = SYMBOLIC_REWRITER.get_context()
backend = get_backend(ctx.cfg)
if backend == Backend.PPLNN or backend == Backend.TENSORRT:
domain = 'mmcv'

View File

@ -11,7 +11,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER
# is not visible in mmcv.
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.roi_align_rotated.__self__', backend='default')
def roi_align_rotated_default(ctx, g, input: Tensor, rois: Tensor,
def roi_align_rotated_default(g, input: Tensor, rois: Tensor,
output_size: List[int], spatial_scale: float,
sampling_ratio: int, aligned: bool,
clockwise: bool):

View File

@ -6,7 +6,7 @@ from mmdeploy.utils import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcv.cnn.bricks.transformer.PatchEmbed.forward',
backend=Backend.NCNN.value)
def patch_embed__forward__ncnn(ctx, self, x):
def patch_embed__forward__ncnn(self, x):
"""Rewrite `forward` of PatchEmbed for ncnn backend.
Args:

View File

@ -9,8 +9,9 @@ from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.adaptive_avg_pool2d')
def adaptive_avg_pool2d__default(ctx, input, output_size):
def adaptive_avg_pool2d__default(input, output_size):
"""Rewrite `adaptive_avg_pool2d` for default backend."""
ctx = FUNCTION_REWRITER.get_context()
output_size = _pair(output_size)
if int(output_size[0]) == int(output_size[1]) == 1:
out = ctx.origin_func(input, output_size)
@ -39,6 +40,7 @@ def adaptive_avg_pool2d__default(ctx, input, output_size):
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.adaptive_avg_pool2d',
backend=Backend.TORCHSCRIPT.value)
def adaptive_avg_pool2d__ncnn(ctx, input, output_size):
def adaptive_avg_pool2d__ncnn(input, output_size):
"""Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backends."""
ctx = FUNCTION_REWRITER.get_context()
return ctx.origin_func(input, output_size)

View File

@ -7,7 +7,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.atan2', backend='default')
def atan2__default(
ctx,
input1: torch.Tensor,
input2: torch.Tensor,
):

View File

@ -7,7 +7,7 @@ from mmdeploy.utils import IR
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.chunk', backend='ncnn')
def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor:
def chunk__ncnn(self, num_chunks: int, dim: int = 0) -> torch.Tensor:
"""Rewrite `chunk` for NCNN backend.
`chunk` is not supported by ncnn, so it should be rewritten.
@ -36,10 +36,7 @@ def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor:
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.chunk', ir=IR.TORCHSCRIPT)
def chunk__torchscript(ctx,
self,
num_chunks: int,
dim: int = 0) -> torch.Tensor:
def chunk__torchscript(self, num_chunks: int, dim: int = 0) -> torch.Tensor:
"""Rewrite `chunk` for Torchscript.
Replace chunk op with split op

View File

@ -13,11 +13,12 @@ from mmdeploy.utils import Backend
func_name='torch.Tensor.clamp', backend=Backend.COREML.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.clamp', backend=Backend.COREML.value)
def clip__coreml(ctx, input, min=None, max=None, **kwargs) -> torch.Tensor:
def clip__coreml(input, min=None, max=None, **kwargs) -> torch.Tensor:
"""Rewrite `clip` for coreml backend.
Cast `min` and `max` to tensors when needed.
"""
ctx = FUNCTION_REWRITER.get_context()
if min is not None and not isinstance(min, torch.Tensor):
min = input.new_tensor(min)

View File

@ -6,11 +6,12 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.expand', backend='ncnn')
def expand__ncnn(ctx, self, *sizes) -> torch.Tensor:
def expand__ncnn(self, *sizes) -> torch.Tensor:
"""Rewrite `expand` for NCNN backend.
Do not expand the batch dim for tensors with ndim >= 3.
"""
ctx = FUNCTION_REWRITER.get_context()
if self.ndim < 3 or sizes[0] not in [1, -1]:
return ctx.origin_func(*sizes)
return self
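A minimal illustration of the rule, with hypothetical shapes:

```python
import torch

x = torch.rand(1, 16, 16)
# ndim >= 3 and sizes[0] in (1, -1): the rewrite returns `x` unchanged,
# so no Expand is emitted on the batch dim during export
y = x.expand(1, 16, 16)
```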

View File

@ -13,7 +13,7 @@ from mmdeploy.utils import Backend
func_name='torch.Tensor.flatten', backend=Backend.COREML.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.flatten', backend=Backend.COREML.value)
def flatten__coreml(ctx, input, start_dim=0, end_dim=-1) -> torch.Tensor:
def flatten__coreml(input, start_dim=0, end_dim=-1) -> torch.Tensor:
"""Rewrite `flatten` for coreml backend.
Use reshape instead of flatten

View File

@ -6,13 +6,14 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.__getattribute__', backend='ncnn')
def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str):
def tensor__getattribute__ncnn(self: torch.Tensor, name: str):
"""Rewrite `__getattribute__` of `torch.Tensor` for ncnn backend.
The Shape node is not supported by ncnn. This function transforms dynamic
shapes into constant shapes.
"""
ctx = FUNCTION_REWRITER.get_context()
ret = ctx.origin_func(self, name)
if name == 'shape':
ret = torch.Size([int(s) for s in ret])
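With this rewrite active during export, shape accesses fold into Python constants; a hypothetical snippet:

```python
import torch

x = torch.rand(1, 3, 8, 8)
h, w = x.shape[2:]          # plain ints instead of traced tensors
y = x.reshape(1, 3, h * w)  # no Shape/Gather nodes in the exported graph
```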

View File

@ -9,7 +9,6 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.group_norm', backend='ncnn')
def group_norm__ncnn(
ctx,
input: torch.Tensor,
num_groups: int,
weight: Union[torch.Tensor, torch.NoneType] = None,

View File

@ -10,8 +10,7 @@ from mmdeploy.utils import Backend, get_root_logger
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.interpolate', backend='ncnn')
def interpolate__ncnn(ctx,
input: torch.Tensor,
def interpolate__ncnn(input: torch.Tensor,
size: Optional[Union[int, Tuple[int], Tuple[int, int],
Tuple[int, int, int]]] = None,
scale_factor: Optional[Union[float,
@ -24,6 +23,7 @@ def interpolate__ncnn(ctx,
ncnn requires `size` to be constant in the ONNX node, so we use
`scale_factor` instead of `size` to avoid a dynamic size.
"""
ctx = FUNCTION_REWRITER.get_context()
input_size = input.shape
if scale_factor is None:
@ -42,8 +42,7 @@ def interpolate__ncnn(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.interpolate', backend='rknn')
def interpolate__rknn(ctx,
input: torch.Tensor,
def interpolate__rknn(input: torch.Tensor,
size: Optional[Union[int, Tuple[int], Tuple[int, int],
Tuple[int, int, int]]] = None,
scale_factor: Optional[Union[float,
@ -56,6 +55,7 @@ def interpolate__rknn(ctx,
rknn requires `size` to be constant in the ONNX node, so we use
`scale_factor` instead of `size` to avoid a dynamic size.
"""
ctx = FUNCTION_REWRITER.get_context()
input_size = input.shape
if scale_factor is None:
scale_factor = [(s_out / s_in)
@ -77,7 +77,6 @@ def interpolate__rknn(ctx,
is_pytorch=True,
backend=Backend.TENSORRT.value)
def interpolate__tensorrt(
ctx,
input: torch.Tensor,
size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int,
int]]] = None,
@ -87,6 +86,7 @@ def interpolate__tensorrt(
recompute_scale_factor: Optional[bool] = None,
):
"""Register default symbolic function for `interpolate`."""
ctx = FUNCTION_REWRITER.get_context()
class BicubicInterpolate(Function):

View File

@ -30,7 +30,6 @@ class GemmOp(torch.autograd.Function):
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.linear', backend='ncnn')
def linear__ncnn(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[Union[torch.Tensor, torch.NoneType]] = None,
@ -41,6 +40,7 @@ def linear__ncnn(
Add extra reshape and transpose ops to support the linear operation with
different input shapes.
"""
ctx = FUNCTION_REWRITER.get_context()
origin_func = ctx.origin_func
dim = input.dim()

View File

@ -13,13 +13,14 @@ from mmdeploy.utils.constants import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value)
def masked_fill__onnxruntime(
ctx, input, mask: torch.Tensor, value: Union[torch.Tensor,
Number]) -> torch.Tensor:
input, mask: torch.Tensor, value: Union[torch.Tensor,
Number]) -> torch.Tensor:
"""Rewrite `masked_fill` for onnxruntime backend.
Take the SATRN model as an example: when `value` is set to
`float('-inf')`, the ONNX Runtime inference results turn out to be NaN.
"""
ctx = FUNCTION_REWRITER.get_context()
if value == float('-inf'):
value = -1e34 # hard-coded finite substitute for -inf
return ctx.origin_func(input, mask, value)

View File

@ -11,10 +11,11 @@ from mmdeploy.utils.constants import Backend
# TODO add version control when MOD is supported by TensorRT
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.__mod__', backend=Backend.TENSORRT.value)
def mod__tensorrt(ctx, input: torch.Tensor, other: Union[torch.Tensor,
torch.NumberType],
*args, **kwargs) -> torch.Tensor:
def mod__tensorrt(input: torch.Tensor, other: Union[torch.Tensor,
torch.NumberType], *args,
**kwargs) -> torch.Tensor:
"""Rewrite `mod` when exporting model to ONNX for TensorRT backend."""
ctx = FUNCTION_REWRITER.get_context()
if version.parse(torch.__version__) > version.parse('1.10.0'):
return input - (input // other) * other
return ctx.origin_func(input, other, *args, **kwargs)
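A quick check of the decomposition used on torch > 1.10, where both `%` and `//` follow floor semantics:

```python
import torch

x = torch.tensor([5.0, 7.0, -9.5])
y = torch.tensor([3.0, 2.0, 4.0])
assert torch.allclose(x - (x // y) * y, x % y)
```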

View File

@ -46,7 +46,6 @@ class ScaledDotProductAttentionTRT(torch.autograd.Function):
func_name='torch.nn.functional._scaled_dot_product_attention',
backend=Backend.TENSORRT.value)
def _scaled_dot_product_attention__tensorrt(
ctx,
q: Tensor,
k: Tensor,
v: Tensor,

View File

@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.normalize', backend='ncnn')
def normalize__ncnn(ctx,
input: torch.Tensor,
def normalize__ncnn(input: torch.Tensor,
p: int = 2,
dim: int = 1,
eps: float = 1e-12,
@ -18,6 +17,7 @@ def normalize__ncnn(ctx,
Make sure the L2 norm is applied on the channel dim so it can be exported
to ncnn correctly.
"""
ctx = FUNCTION_REWRITER.get_context()
if dim < 0:
dim += input.ndim
assert dim != 0, 'Should not normalize on batch index'

View File

@ -11,7 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.onnx.symbolic_opset11._prepare_onnx_paddings',
backend='tensorrt')
def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad):
def _prepare_onnx_paddings__tensorrt(g, input, pad):
"""Rewrite `_prepare_onnx_paddings` for TensorRT backend.
For code like `x = torch.nn.ZeroPad2d((0, a, 0, b))(x)`, where a and b are
@ -26,6 +26,7 @@ def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad):
..., dim_m_begin, dim_m_end,
where m is in range [0, n].
"""
ctx = FUNCTION_REWRITER.get_context()
torch_version = version_parse(torch.__version__)
if torch_version.major == 1 and torch_version.minor < 10:
return ctx.origin_func(g, input, pad)
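A pure-Python sketch of the reordering this function performs, with hypothetical values:

```python
# torch F.pad order (last dim first): (w_begin, w_end, h_begin, h_end, ...)
torch_pad = [0, 3, 0, 5]   # ZeroPad2d((0, 3, 0, 5)), i.e. a=3, b=5
rank = 4                   # NCHW input
pad = torch_pad + [0] * (2 * rank - len(torch_pad))  # cover all dims
begins = pad[0::2][::-1]   # per-dim begins, reordered to dim-0-first
ends = pad[1::2][::-1]     # per-dim ends, reordered to dim-0-first
onnx_pads = begins + ends  # -> [0, 0, 0, 0, 0, 0, 5, 3]
```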

Some files were not shown because too many files have changed in this diff.