[Refactor] Refactor rewriter context for MMRazor (#1483)
* wip * update rewriter * Support all codebase * update docs * fix ssd * rename qualname * support torch.fx.wrap * import by torch version Co-authored-by: pppppM <gjf_mail@126.com>pull/1536/head
parent
78901a2451
commit
3f261e6d50
demo
docs
en/07-developer-guide
zh_cn/07-developer-guide
mmdeploy
apis/onnx
codebase
mmaction/models/recognizers
mmcls/models
mmdet
deploy
models
detectors
task_modules
structures/bbox
mmdet3d/models
mmedit/models/base_models
mmocr/models
mmpose/models
heads
pose_estimators
mmrotate
models
dense_heads
task_modules
structures
mmseg/models
decode_heads
utils
core
optimizers
|
@ -13,7 +13,7 @@ from mmdeploy.utils import get_root_logger
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torchvision.models.ResNet._forward_impl')
|
||||
def forward_of_resnet(ctx, self, x):
|
||||
def forward_of_resnet(self, x):
|
||||
"""Rewrite the forward implementation of resnet.
|
||||
|
||||
Early return the feature map after two down-sampling steps.
|
||||
|
|
|
@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
|
|||
|
||||
@mark(
|
||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
def __forward_impl(self, img, img_metas=None, **kwargs):
|
||||
...
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.detectors.base.BaseDetector.forward')
|
||||
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
||||
def base_detector__forward(self, img, img_metas=None, **kwargs):
|
||||
...
|
||||
# call the mark function
|
||||
return __forward_impl(...)
|
||||
|
@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
|
||||
def yolov3_head__get_bboxes(ctx,
|
||||
self,
|
||||
def yolov3_head__get_bboxes(self,
|
||||
pred_maps,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
|
|
|
@ -11,7 +11,8 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.repeat', backend='tensorrt')
|
||||
def repeat_static(ctx, input, *size):
|
||||
def repeat_static(input, *size):
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
origin_func = ctx.origin_func
|
||||
if input.dim() == 1 and len(size) == 1:
|
||||
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
|
||||
|
@ -72,7 +73,7 @@ The mappings between PyTorch and ONNX are defined in PyTorch with symbolic funct
|
|||
|
||||
```python
|
||||
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
|
||||
def squeeze_default(ctx, g, self, dim=None):
|
||||
def squeeze_default(g, self, dim=None):
|
||||
if dim is None:
|
||||
dims = []
|
||||
for i, size in enumerate(self.type().sizes()):
|
||||
|
|
|
@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
# Custom rewritten function
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
||||
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
|
||||
def forward_of_base_classifier(self, img, *args, **kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
return self.simple_test(img, {})
|
||||
```
|
||||
|
@ -63,7 +63,8 @@ In the first example, the output is generated in Python. Sometimes we may make b
|
|||
# Custom rewritten function
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
|
||||
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
|
||||
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
assert isinstance(img_metas, dict)
|
||||
|
|
|
@ -13,13 +13,13 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
|
|||
|
||||
@mark(
|
||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
|
||||
def __forward_impl(self, img, img_metas=None, **kwargs):
|
||||
...
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.detectors.base.BaseDetector.forward')
|
||||
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
|
||||
def base_detector__forward(self, img, img_metas=None, **kwargs):
|
||||
...
|
||||
# call the mark function
|
||||
return __forward_impl(...)
|
||||
|
@ -32,8 +32,7 @@ from mmdeploy.core import FUNCTION_REWRITER, mark
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
|
||||
def yolov3_head__get_bboxes(ctx,
|
||||
self,
|
||||
def yolov3_head__get_bboxes(self,
|
||||
pred_maps,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
|
|
|
@ -10,7 +10,8 @@ PyTorch 神经网络是用 python 编写的,可以简化算法的开发。但
|
|||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.repeat', backend='tensorrt')
|
||||
def repeat_static(ctx, input, *size):
|
||||
def repeat_static(input, *size):
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
origin_func = ctx.origin_func
|
||||
if input.dim() == 1 and len(size) == 1:
|
||||
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
|
||||
|
@ -67,7 +68,7 @@ PyTorch 和 ONNX 之间的映射是通过 PyTorch 中的符号函数进行定义
|
|||
|
||||
```python
|
||||
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
|
||||
def squeeze_default(ctx, g, self, dim=None):
|
||||
def squeeze_default(g, self, dim=None):
|
||||
if dim is None:
|
||||
dims = []
|
||||
for i, size in enumerate(self.type().sizes()):
|
||||
|
|
|
@ -18,7 +18,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
# Custom rewritten function
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
||||
def forward_of_base_classifier(ctx, self, img, *args, **kwargs):
|
||||
def forward_of_base_classifier(self, img, *args, **kwargs):
|
||||
"""Rewrite `forward` for default backend."""
|
||||
return self.simple_test(img, {})
|
||||
```
|
||||
|
@ -63,7 +63,8 @@ def test_baseclassfier_forward():
|
|||
# Custom rewritten function
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
|
||||
def base_segmentor__forward(ctx, self, img, img_metas=None, **kwargs):
|
||||
def base_segmentor__forward(self, img, img_metas=None, **kwargs):
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if img_metas is None:
|
||||
img_metas = {}
|
||||
assert isinstance(img_metas, dict)
|
||||
|
|
|
@ -116,6 +116,15 @@ def export(model: torch.nn.Module,
|
|||
input_metas, dict
|
||||
), f'Expect input_metas type is dict, get {type(input_metas)}.'
|
||||
model_forward = patched_model.forward
|
||||
|
||||
def wrap_forward(forward):
|
||||
|
||||
def wrapper(*arg, **kwargs):
|
||||
return forward(*arg, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
patched_model.forward = wrap_forward(patched_model.forward)
|
||||
patched_model.forward = partial(patched_model.forward,
|
||||
**input_metas)
|
||||
|
||||
|
|
|
@ -5,8 +5,9 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
|
||||
def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
|
||||
def model_to_graph__custom_optimizer(*args, **kwargs):
|
||||
"""Rewriter of _model_to_graph, add custom passes."""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
|
||||
|
||||
custom_passes = getattr(ctx, 'onnx_custom_passes', None)
|
||||
|
@ -23,8 +24,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'torch._C._jit_pass_onnx_deduplicate_initializers', backend='tensorrt')
|
||||
def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
|
||||
arg2):
|
||||
def jit_pass_onnx_deduplicate_initializers__disable(graph, param_dict, arg2):
|
||||
"""This pass will disable TensorRT topk export.
|
||||
|
||||
disable for TensorRT.
|
||||
|
@ -34,6 +34,6 @@ def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'torch._C._jit_pass_onnx_autograd_function_process')
|
||||
def jit_pass_onnx_autograd_function_process__disable(ctx, graph):
|
||||
def jit_pass_onnx_autograd_function_process__disable(graph):
|
||||
"""Disable process autograph function."""
|
||||
return
|
||||
|
|
|
@ -8,8 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmaction.models.recognizers.BaseRecognizer.forward')
|
||||
def base_recognizer__forward(ctx,
|
||||
self,
|
||||
def base_recognizer__forward(self,
|
||||
inputs: Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward')
|
||||
def shufflenetv2_backbone__forward__default(ctx, self, x):
|
||||
def shufflenetv2_backbone__forward__default(self, x):
|
||||
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2.
|
||||
|
||||
The chunk in original InvertedResidual.forward will convert to dynamic
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmdeploy.utils import Backend
|
|||
func_name= # noqa: E251
|
||||
'mmcls.models.backbones.vision_transformer.VisionTransformer.forward',
|
||||
backend=Backend.NCNN.value)
|
||||
def visiontransformer__forward__ncnn(ctx, self, x):
|
||||
def visiontransformer__forward__ncnn(self, x):
|
||||
"""Rewrite `forward` of VisionTransformer for ncnn backend.
|
||||
|
||||
The chunk in original VisionTransformer.forward will convert
|
||||
|
|
|
@ -12,7 +12,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcls.models.classifiers.BaseClassifier.forward', backend='default')
|
||||
def base_classifier__forward(
|
||||
ctx,
|
||||
self,
|
||||
batch_inputs: Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmdeploy.utils import Backend
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmcls.models.necks.GlobalAveragePooling.forward',
|
||||
backend=Backend.DEFAULT.value)
|
||||
def gap__forward(ctx, self, inputs):
|
||||
def gap__forward(self, inputs):
|
||||
"""Rewrite `forward` of GlobalAveragePooling for default backend.
|
||||
|
||||
Replace `view` with `flatten` to export simple onnx graph.
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmdeploy.utils import Backend, get_dynamic_axes
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcls.models.utils.attention.MultiheadAttention.forward',
|
||||
backend=Backend.NCNN.value)
|
||||
def multiheadattention__forward__ncnn(ctx, self, qkv_input):
|
||||
def multiheadattention__forward__ncnn(self, qkv_input):
|
||||
"""Rewrite `forward` of MultiheadAttention used in vision_transformer for
|
||||
ncnn backend.
|
||||
|
||||
|
@ -53,12 +53,13 @@ def multiheadattention__forward__ncnn(ctx, self, qkv_input):
|
|||
func_name= # noqa: E251
|
||||
'mmcls.models.utils.ShiftWindowMSA.forward',
|
||||
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
|
||||
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
|
||||
def shift_window_msa__forward__default(self, query, hw_shape):
|
||||
"""Rewrite forward function of ShiftWindowMSA class for TensorRT.
|
||||
|
||||
1. replace dynamic padding with static padding and dynamic slice.
|
||||
2. always do slice `x = x[:, :H, :W, :].contiguous()` for stability.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if get_dynamic_axes(ctx.cfg) is None:
|
||||
# avoid the weird bug of torch to onnx
|
||||
return ctx.origin_func(self, query, hw_shape)
|
||||
|
@ -142,8 +143,7 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape):
|
|||
func_name= # noqa: E251
|
||||
'mmcls.models.utils.ShiftWindowMSA.get_attn_mask',
|
||||
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
|
||||
def shift_window_msa__get_attn_mask__default(ctx,
|
||||
self,
|
||||
def shift_window_msa__get_attn_mask__default(self,
|
||||
hw_shape,
|
||||
window_size,
|
||||
shift_size,
|
||||
|
|
|
@ -76,7 +76,7 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
|
|||
func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes',
|
||||
backend='tensorrt',
|
||||
extra_checkers=LibVersionChecker('tensorrt', min_version='8'))
|
||||
def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
|
||||
def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor,
|
||||
max_shape: Union[Tensor, Sequence[int]]):
|
||||
"""Clip bboxes for onnx. From TensorRT 8 we can do the operators on the
|
||||
tensors directly.
|
||||
|
@ -165,7 +165,6 @@ def __pad_with_value_if_necessary(x: Tensor,
|
|||
'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def __pad_with_value_if_necessary__tensorrt(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
pad_dim: int,
|
||||
pad_size: int,
|
||||
|
@ -223,12 +222,12 @@ class TRTGatherTopk(torch.autograd.Function):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def __gather_topk__trt(ctx,
|
||||
*inputs: Sequence[torch.Tensor],
|
||||
def __gather_topk__trt(*inputs: Sequence[torch.Tensor],
|
||||
inds: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_batched: bool = True) -> Tuple[torch.Tensor]:
|
||||
"""TensorRT gather_topk."""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
_ = ctx
|
||||
if is_batched:
|
||||
index_shape = inds.shape
|
||||
|
@ -253,8 +252,7 @@ def __gather_topk__trt(ctx,
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
|
||||
backend=Backend.COREML.value)
|
||||
def __gather_topk__nonbatch(ctx,
|
||||
*inputs: Sequence[torch.Tensor],
|
||||
def __gather_topk__nonbatch(*inputs: Sequence[torch.Tensor],
|
||||
inds: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_batched: bool = True) -> Tuple[torch.Tensor]:
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.utils import get_common_config
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.csp_darknet.Focus.forward')
|
||||
def focus__forward__default(ctx, self, x):
|
||||
def focus__forward__default(self, x):
|
||||
"""Rewrite forward function of Focus class.
|
||||
|
||||
Replace slice with transpose.
|
||||
|
@ -27,7 +27,7 @@ def focus__forward__default(ctx, self, x):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.csp_darknet.Focus.forward',
|
||||
backend='ncnn')
|
||||
def focus__forward__ncnn(ctx, self, x):
|
||||
def focus__forward__ncnn(self, x):
|
||||
"""Rewrite forward function of Focus class for ncnn.
|
||||
|
||||
Focus width and height information into channel space. ncnn does not
|
||||
|
@ -69,7 +69,7 @@ def focus__forward__ncnn(ctx, self, x):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.swin.WindowMSA.forward',
|
||||
backend='tensorrt')
|
||||
def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
|
||||
def windowmsa__forward__tensorrt(self, x, mask=None):
|
||||
"""Rewrite forward function of WindowMSA class for TensorRT.
|
||||
|
||||
1. replace Gather operation of qkv with split.
|
||||
|
@ -80,6 +80,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
|
|||
mask (tensor | None, Optional): mask with shape of (num_windows,
|
||||
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
-1).permute(2, 0, 3, 1, 4).contiguous()
|
||||
|
@ -129,7 +130,7 @@ def windowmsa__forward__tensorrt(ctx, self, x, mask=None):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_reverse',
|
||||
backend='tensorrt')
|
||||
def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
|
||||
def shift_window_msa__window_reverse__tensorrt(self, windows, H, W):
|
||||
"""Rewrite window_reverse function of ShiftWindowMSA class for TensorRT.
|
||||
For TensorRT, seems radical shape transformations are not allowed. Replace
|
||||
them with soft ones.
|
||||
|
@ -155,7 +156,7 @@ def shift_window_msa__window_reverse__tensorrt(ctx, self, windows, H, W):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.window_partition',
|
||||
backend='tensorrt')
|
||||
def shift_window_msa__window_partition__tensorrt(ctx, self, x):
|
||||
def shift_window_msa__window_partition__tensorrt(self, x):
|
||||
"""Rewrite window_partition function of ShiftWindowMSA class for TensorRT.
|
||||
For TensorRT, seems radical shape transformations are not allowed. Replace
|
||||
them with soft ones.
|
||||
|
@ -176,7 +177,7 @@ def shift_window_msa__window_partition__tensorrt(ctx, self, x):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.backbones.swin.ShiftWindowMSA.forward')
|
||||
def shift_window_msa__forward__default(ctx, self, query, hw_shape):
|
||||
def shift_window_msa__forward__default(self, query, hw_shape):
|
||||
"""Rewrite forward function of ShiftWindowMSA class.
|
||||
|
||||
1. replace dynamic padding with static padding and dynamic slice.
|
||||
|
|
|
@ -24,7 +24,6 @@ from mmdeploy.utils import Backend, is_dynamic_shape
|
|||
func_name='mmdet.models.dense_heads.base_dense_head.'
|
||||
'BaseDenseHead.predict_by_feat')
|
||||
def base_dense_head__predict_by_feat(
|
||||
ctx,
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
|
@ -66,6 +65,7 @@ def base_dense_head__predict_by_feat(
|
|||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
|
@ -211,7 +211,6 @@ def base_dense_head__predict_by_feat(
|
|||
'BaseDenseHead.predict_by_feat',
|
||||
backend=Backend.RKNN.value)
|
||||
def base_dense_head__predict_by_feat__rknn(
|
||||
ctx,
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
|
@ -253,6 +252,8 @@ def base_dense_head__predict_by_feat__rknn(
|
|||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
# mark nodes for partition
|
||||
@mark('BaseDenseHead', outputs=['BaseDenseHead.cls', 'BaseDenseHead.loc'])
|
||||
def __mark_dense_head(cls_scores, bbox_preds):
|
||||
|
@ -337,7 +338,6 @@ def base_dense_head__predict_by_feat__rknn(
|
|||
'BaseDenseHead.predict_by_feat',
|
||||
backend=Backend.NCNN.value)
|
||||
def base_dense_head__predict_by_feat__ncnn(
|
||||
ctx,
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
|
@ -376,6 +376,7 @@ def base_dense_head__predict_by_feat__ncnn(
|
|||
Returns:
|
||||
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\
|
||||
|
|
|
@ -9,7 +9,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.centernet_head.CenterNetHead.predict_by_feat')
|
||||
def centernet_head__predict_by_feat__default(
|
||||
ctx,
|
||||
self,
|
||||
center_heatmap_preds: List[Tensor],
|
||||
wh_preds: List[Tensor],
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.DETRHead.forward_single')
|
||||
def detrhead__forward_single__default(ctx, self, x, img_metas):
|
||||
def detrhead__forward_single__default(self, x, img_metas):
|
||||
"""forward_single of DETRHead.
|
||||
|
||||
Ease the mask computation
|
||||
|
@ -35,8 +35,7 @@ def detrhead__forward_single__default(ctx, self, x, img_metas):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.DETRHead.predict_by_feat')
|
||||
def detrhead__predict_by_feat__default(ctx,
|
||||
self,
|
||||
def detrhead__predict_by_feat__default(self,
|
||||
all_cls_scores_list: List[Tensor],
|
||||
all_bbox_preds_list: List[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
|
|
|
@ -13,8 +13,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.fovea_head.FoveaHead.predict_by_feat')
|
||||
def fovea_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def fovea_head__predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
score_factors: Optional[List[Tensor]] = None,
|
||||
|
@ -49,6 +48,7 @@ def fovea_head__predict_by_feat(ctx,
|
|||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
num_levels = len(cls_scores)
|
||||
|
|
|
@ -18,8 +18,7 @@ from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.gfl_head.'
|
||||
'GFLHead.predict_by_feat')
|
||||
def gfl_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def gfl_head__predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
score_factors: Optional[List[Tensor]] = None,
|
||||
|
@ -58,6 +57,7 @@ def gfl_head__predict_by_feat(ctx,
|
|||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
backend = get_backend(deploy_cfg)
|
||||
|
|
|
@ -35,12 +35,13 @@ def _bbox_post_decode(bboxes: torch.Tensor, max_shape: Sequence[int]):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.reppoints_head.RepPointsHead.points2bbox')
|
||||
def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
|
||||
def reppoints_head__points2bbox(self, pts, y_first=True):
|
||||
"""Rewrite of `points2bbox` in `RepPointsHead`.
|
||||
|
||||
Use `self.moment_transfer` in `points2bbox` will cause error:
|
||||
RuntimeError: Input, output and indices must be on the current device
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
update_moment = hasattr(self, 'moment_transfer')
|
||||
if update_moment:
|
||||
moment_transfer = self.moment_transfer
|
||||
|
@ -55,7 +56,6 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.reppoints_head.RepPointsHead.predict_by_feat')
|
||||
def reppoints_head__predict_by_feat(
|
||||
ctx,
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
|
@ -91,6 +91,7 @@ def reppoints_head__predict_by_feat(
|
|||
`dets` of shape [N, num_det, 5] and `labels` of shape
|
||||
[N, num_det].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
|
|
|
@ -16,8 +16,7 @@ from mmdeploy.utils import Backend, is_dynamic_shape
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.rpn_head.'
|
||||
'RPNHead.predict_by_feat')
|
||||
def rpn_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def rpn_head__predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
score_factors: Optional[List[Tensor]] = None,
|
||||
|
@ -61,6 +60,7 @@ def rpn_head__predict_by_feat(ctx,
|
|||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
img_metas = batch_img_metas
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
|
@ -166,8 +166,7 @@ def rpn_head__predict_by_feat(ctx,
|
|||
# TODO: Fix for 1.x
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value)
|
||||
def rpn_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
def rpn_head__get_bboxes__ncnn(self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
|
@ -204,6 +203,7 @@ def rpn_head__get_bboxes__ncnn(ctx,
|
|||
Else:
|
||||
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg)
|
||||
|
|
|
@ -14,8 +14,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.rtmdet_head.'
|
||||
'RTMDetHead.predict_by_feat')
|
||||
def rtmdet_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def rtmdet_head__predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
batch_img_metas: Optional[List[dict]] = None,
|
||||
|
@ -52,6 +51,7 @@ def rtmdet_head__predict_by_feat(ctx,
|
|||
tensor in the tuple is (N, num_box), and each element
|
||||
represents the class label of the corresponding box.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
device = cls_scores[0].device
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
|
|
|
@ -15,8 +15,7 @@ from mmdeploy.utils import Backend, is_dynamic_shape
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.yolo_head.'
|
||||
'YOLOV3Head.predict_by_feat')
|
||||
def yolov3_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def yolov3_head__predict_by_feat(self,
|
||||
pred_maps: Sequence[Tensor],
|
||||
cfg: OptConfigType = None,
|
||||
rescale: bool = False,
|
||||
|
@ -47,6 +46,7 @@ def yolov3_head__predict_by_feat(ctx,
|
|||
Else:
|
||||
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
deploy_cfg = ctx.cfg
|
||||
|
||||
# mark pred_maps
|
||||
|
@ -154,8 +154,7 @@ def yolov3_head__predict_by_feat(ctx,
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.YOLOV3Head.predict_by_feat',
|
||||
backend=Backend.NCNN.value)
|
||||
def yolov3_head__predict_by_feat__ncnn(ctx,
|
||||
self,
|
||||
def yolov3_head__predict_by_feat__ncnn(self,
|
||||
pred_maps,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
|
@ -188,6 +187,7 @@ def yolov3_head__predict_by_feat__ncnn(ctx,
|
|||
fore-ground class label in Yolov3DetectionOutput starts
|
||||
from `1`. x1, y1, x2, y2 are normalized in range(0,1).
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
num_levels = len(pred_maps)
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
post_params = get_post_processing_params(ctx.cfg)
|
||||
|
|
|
@ -15,8 +15,7 @@ from mmdeploy.utils import Backend
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.yolox_head.'
|
||||
'YOLOXHead.predict_by_feat')
|
||||
def yolox_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def yolox_head__predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
objectnesses: Optional[List[Tensor]],
|
||||
|
@ -57,6 +56,8 @@ def yolox_head__predict_by_feat(ctx,
|
|||
tensor in the tuple is (N, num_box), and each element
|
||||
represents the class label of the corresponding box.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
# mark pred_maps
|
||||
@mark('yolo_head', inputs=['cls_scores', 'bbox_preds', 'objectnesses'])
|
||||
def __mark_pred_maps(cls_scores, bbox_preds, objectnesses):
|
||||
|
@ -125,7 +126,6 @@ def yolox_head__predict_by_feat(ctx,
|
|||
'YOLOXHead.predict_by_feat',
|
||||
backend=Backend.NCNN.value)
|
||||
def yolox_head__predict_by_feat__ncnn(
|
||||
ctx,
|
||||
self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
|
@ -169,6 +169,7 @@ def yolox_head__predict_by_feat__ncnn(
|
|||
Returns:
|
||||
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward
|
||||
from mmdeploy.utils import get_root_logger
|
||||
from mmdeploy.utils.config_utils import is_dynamic_shape
|
||||
|
|
|
@ -12,7 +12,7 @@ from mmdeploy.utils import is_dynamic_shape
|
|||
|
||||
@mark(
|
||||
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
|
||||
def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs):
|
||||
def __forward_impl(self, batch_inputs, data_samples):
|
||||
"""Rewrite and adding mark for `forward`.
|
||||
|
||||
Encapsulate this function for rewriting `forward` of BaseDetector.
|
||||
|
@ -25,10 +25,30 @@ def __forward_impl(ctx, self, batch_inputs, data_samples, **kwargs):
|
|||
return output
|
||||
|
||||
|
||||
@torch.fx.wrap
|
||||
def _set_metainfo(data_samples, img_shape):
|
||||
"""Set the metainfo.
|
||||
|
||||
Code in this function cannot be traced by fx.
|
||||
"""
|
||||
|
||||
# fx can not trace deepcopy correctly
|
||||
data_samples = copy.deepcopy(data_samples)
|
||||
if data_samples is None:
|
||||
data_samples = [DetDataSample()]
|
||||
|
||||
# note that we can not use `set_metainfo`, deepcopy would crash the
|
||||
# onnx trace.
|
||||
for data_sample in data_samples:
|
||||
data_sample.set_field(
|
||||
name='img_shape', value=img_shape, field_type='metainfo')
|
||||
|
||||
return data_samples
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.detectors.single_stage.SingleStageDetector.forward')
|
||||
def single_stage_detector__forward(ctx,
|
||||
self,
|
||||
def single_stage_detector__forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
|
@ -53,9 +73,8 @@ def single_stage_detector__forward(ctx,
|
|||
- labels (Tensor): Labels of bboxes, has a shape
|
||||
(num_instances, ).
|
||||
"""
|
||||
data_samples = copy.deepcopy(data_samples)
|
||||
if data_samples is None:
|
||||
data_samples = [DetDataSample()]
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
deploy_cfg = ctx.cfg
|
||||
|
||||
# get origin input shape as tensor to support onnx dynamic shape
|
||||
|
@ -65,11 +84,6 @@ def single_stage_detector__forward(ctx,
|
|||
img_shape = [int(val) for val in img_shape]
|
||||
|
||||
# set the metainfo
|
||||
# note that we can not use `set_metainfo`, deepcopy would crash the
|
||||
# onnx trace.
|
||||
for data_sample in data_samples:
|
||||
data_sample.set_field(
|
||||
name='img_shape', value=img_shape, field_type='metainfo')
|
||||
data_samples = _set_metainfo(data_samples, img_shape)
|
||||
|
||||
return __forward_impl(
|
||||
ctx, self, batch_inputs, data_samples=data_samples, **kwargs)
|
||||
return __forward_impl(self, batch_inputs, data_samples=data_samples)
|
||||
|
|
|
@ -11,8 +11,7 @@ from mmdeploy.utils import is_dynamic_shape
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.detectors.two_stage.TwoStageDetector.extract_feat')
|
||||
@mark('extract_feat', inputs='img', outputs='feat')
|
||||
def two_stage_detector__extract_feat(ctx, self, img):
|
||||
def two_stage_detector__extract_feat(self, img):
|
||||
"""Rewrite `extract_feat` for default backend.
|
||||
|
||||
This function uses the specific `extract_feat` function for the two
|
||||
|
@ -27,13 +26,18 @@ def two_stage_detector__extract_feat(ctx, self, img):
|
|||
list[Tensor]: Each item with shape (N, C, H, W) corresponds one
|
||||
level of backbone and neck features.
|
||||
"""
|
||||
return ctx.origin_func(self, img)
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
@mark('extract_feat', inputs='img', outputs='feat')
|
||||
def __extract_feat_impl(self, img):
|
||||
return ctx.origin_func(self, img)
|
||||
|
||||
return __extract_feat_impl(self, img)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.detectors.two_stage.TwoStageDetector.forward')
|
||||
def two_stage_detector__forward(ctx,
|
||||
self,
|
||||
def two_stage_detector__forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
data_samples: OptSampleList = None,
|
||||
mode: str = 'tensor',
|
||||
|
@ -58,6 +62,7 @@ def two_stage_detector__forward(ctx,
|
|||
- labels (Tensor): Labels of bboxes, has a shape
|
||||
(num_instances, ).
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
data_samples = copy.deepcopy(data_samples)
|
||||
deploy_cfg = ctx.cfg
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.utils import Backend, get_root_logger
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.necks.ssd_neck.L2Norm.forward')
|
||||
def l2norm__forward__default(ctx, self, x):
|
||||
def l2norm__forward__default(self, x):
|
||||
"""Default rewriter for l2norm.
|
||||
|
||||
Implement with functinoal.normalize .
|
||||
|
@ -19,11 +19,12 @@ def l2norm__forward__default(ctx, self, x):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.necks.ssd_neck.L2Norm.forward',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def l2norm__forward__tensorrt(ctx, self, x):
|
||||
def l2norm__forward__tensorrt(self, x):
|
||||
"""rewrite `l2norm` for TensorRT.
|
||||
|
||||
TensorRT7 does not support dynamic clamp, which is used in normalize.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
logger = get_root_logger()
|
||||
trt_version_major = 8
|
||||
try:
|
||||
|
@ -34,6 +35,6 @@ def l2norm__forward__tensorrt(ctx, self, x):
|
|||
except Exception:
|
||||
logger.warning('Can not get TensorRT version.')
|
||||
if trt_version_major >= 8:
|
||||
return l2norm__forward__default(ctx, self, x)
|
||||
return l2norm__forward__default(self, x)
|
||||
else:
|
||||
return ctx.origin_func(self, x)
|
||||
|
|
|
@ -17,11 +17,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.bbox_heads.convfc_bbox_head.ConvFCBBoxHead.forward'
|
||||
)
|
||||
@mark(
|
||||
'bbox_head_forward',
|
||||
inputs=['bbox_feats'],
|
||||
outputs=['cls_score', 'bbox_pred'])
|
||||
def bbox_head__forward(ctx, self, x):
|
||||
def bbox_head__forward(self, x):
|
||||
"""Rewrite `forward` for default backend.
|
||||
|
||||
This function uses the specific `forward` function for the BBoxHead
|
||||
|
@ -37,13 +33,21 @@ def bbox_head__forward(ctx, self, x):
|
|||
has shape (N, num_det, num_cls) and the bbox_pred has shape
|
||||
(N, num_det, 4).
|
||||
"""
|
||||
return ctx.origin_func(self, x)
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
@mark(
|
||||
'bbox_head_forward',
|
||||
inputs=['bbox_feats'],
|
||||
outputs=['cls_score', 'bbox_pred'])
|
||||
def __forward(self, x):
|
||||
return ctx.origin_func(self, x)
|
||||
|
||||
return __forward(self, x)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.bbox_heads.bbox_head.BBoxHead.predict_by_feat')
|
||||
def bbox_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def bbox_head__predict_by_feat(self,
|
||||
rois: Tuple[Tensor],
|
||||
cls_scores: Tuple[Tensor],
|
||||
bbox_preds: Tuple[Tensor],
|
||||
|
@ -75,6 +79,7 @@ def bbox_head__predict_by_feat(ctx,
|
|||
- labels (Tensor): Labels of bboxes, has a shape
|
||||
(num_instances, ).
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
assert rois.ndim == 3, 'Only support export two stage ' \
|
||||
'model to ONNX ' \
|
||||
'with batch dimension. '
|
||||
|
|
|
@ -11,8 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_bbox')
|
||||
def cascade_roi_head__predict_bbox(ctx,
|
||||
self,
|
||||
def cascade_roi_head__predict_bbox(self,
|
||||
x: Tuple[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
rpn_results_list: List[Tensor],
|
||||
|
@ -86,8 +85,7 @@ def cascade_roi_head__predict_bbox(ctx,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.cascade_roi_head.CascadeRoIHead.predict_mask')
|
||||
def cascade_roi_head__predict_mask(ctx,
|
||||
self,
|
||||
def cascade_roi_head__predict_mask(self,
|
||||
x: Tuple[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
results_list: List[Tensor],
|
||||
|
|
|
@ -14,8 +14,7 @@ from mmdeploy.utils import Backend, get_backend
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.'
|
||||
'mask_heads.fcn_mask_head.FCNMaskHead.predict_by_feat')
|
||||
def fcn_mask_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def fcn_mask_head__predict_by_feat(self,
|
||||
mask_preds: Tuple[Tensor],
|
||||
results_list: List[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
|
@ -48,6 +47,7 @@ def fcn_mask_head__predict_by_feat(ctx,
|
|||
(num_instances, ).
|
||||
- masks (Tensor): Has a shape (num_instances, H, W).
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
ori_shape = batch_img_metas[0]['img_shape']
|
||||
dets, det_labels = results_list
|
||||
dets = dets.view(-1, 5)
|
||||
|
|
|
@ -61,13 +61,12 @@ class MultiLevelRoiAlign(Function):
|
|||
(num_proposals, channel, output_size[1], output_size[0]))
|
||||
|
||||
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.roi_extractors.'
|
||||
'single_level_roi_extractor.SingleRoIExtractor.forward',
|
||||
backend='tensorrt')
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
def single_roi_extractor__forward__tensorrt(ctx,
|
||||
self,
|
||||
def single_roi_extractor__forward__tensorrt(self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
|
@ -154,8 +153,7 @@ class AscendRoiExtractor(Function):
|
|||
'mmdet.models.roi_heads.roi_extractors.'
|
||||
'single_level_roi_extractor.SingleRoIExtractor.forward',
|
||||
backend='ascend')
|
||||
def single_roi_extractor__forward__ascend(ctx,
|
||||
self,
|
||||
def single_roi_extractor__forward__ascend(self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
|
@ -185,14 +183,10 @@ def single_roi_extractor__forward__ascend(ctx,
|
|||
finest_scale, featmap_strides, aligned)
|
||||
|
||||
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward')
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
def single_roi_extractor__forward(ctx,
|
||||
self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
def single_roi_extractor__forward(self, feats, rois, roi_scale_factor=None):
|
||||
"""Rewrite `forward` of SingleRoIExtractor for default backend.
|
||||
|
||||
Rewrite this function to:
|
||||
|
@ -206,6 +200,8 @@ def single_roi_extractor__forward(ctx,
|
|||
|
||||
3. use the roi align in torhcvision to accelerate the inference.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context(
|
||||
'mmdet.models.roi_heads.SingleRoIExtractor.forward')
|
||||
backend = get_backend(ctx.cfg)
|
||||
out_size = self.roi_layers[0].output_size
|
||||
num_levels = len(feats)
|
||||
|
@ -291,8 +287,7 @@ class SingleRoIExtractorOpenVINO(Function):
|
|||
'mmdet.models.roi_heads.roi_extractors.'
|
||||
'single_level_roi_extractor.SingleRoIExtractor.forward',
|
||||
backend='openvino')
|
||||
def single_roi_extractor__forward__openvino(ctx,
|
||||
self,
|
||||
def single_roi_extractor__forward__openvino(self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
|
@ -301,6 +296,7 @@ def single_roi_extractor__forward__openvino(ctx,
|
|||
|
||||
This function uses ExperimentalDetectronROIFeatureExtractor for OpenVINO.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
# Adding original output to SingleRoIExtractorOpenVINO.
|
||||
state = torch._C._get_tracing_state()
|
||||
|
@ -317,12 +313,11 @@ def single_roi_extractor__forward__openvino(ctx,
|
|||
return result
|
||||
|
||||
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.roi_heads.SingleRoIExtractor.forward',
|
||||
backend=Backend.COREML.value)
|
||||
@mark('roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
def single_roi_extractor__forward__coreml(ctx,
|
||||
self,
|
||||
def single_roi_extractor__forward__coreml(self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
|
|
|
@ -10,8 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_bbox')
|
||||
def standard_roi_head__predict_bbox(ctx,
|
||||
self,
|
||||
def standard_roi_head__predict_bbox(self,
|
||||
x: Tuple[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
rpn_results_list: List[Tensor],
|
||||
|
@ -72,8 +71,7 @@ def standard_roi_head__predict_bbox(ctx,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet.models.roi_heads.standard_roi_head.StandardRoIHead.predict_mask')
|
||||
def standard_roi_head__predict_mask(ctx,
|
||||
self,
|
||||
def standard_roi_head__predict_mask(self,
|
||||
x: Tuple[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
results_list: List[Tensor],
|
||||
|
|
|
@ -9,8 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
func_name='mmdet.models.task_modules.coders.delta_xywh_bbox_coder.'
|
||||
'DeltaXYWHBBoxCoder.decode',
|
||||
backend='default')
|
||||
def deltaxywhbboxcoder__decode(ctx,
|
||||
self,
|
||||
def deltaxywhbboxcoder__decode(self,
|
||||
bboxes,
|
||||
pred_bboxes,
|
||||
max_shape=None,
|
||||
|
@ -51,8 +50,7 @@ def deltaxywhbboxcoder__decode(ctx,
|
|||
func_name='mmdet.models.task_modules.coders'
|
||||
'.delta_xywh_bbox_coder.delta2bbox',
|
||||
backend='default')
|
||||
def delta2bbox(ctx,
|
||||
rois,
|
||||
def delta2bbox(rois,
|
||||
deltas,
|
||||
means=(0., 0., 0., 0.),
|
||||
stds=(1., 1., 1., 1.),
|
||||
|
@ -143,8 +141,7 @@ def delta2bbox(ctx,
|
|||
func_name='mmdet.models.task_modules.coders.'
|
||||
'delta_xywh_bbox_coder.delta2bbox',
|
||||
backend='ncnn')
|
||||
def delta2bbox__ncnn(ctx,
|
||||
rois,
|
||||
def delta2bbox__ncnn(rois,
|
||||
deltas,
|
||||
means=(0., 0., 0., 0.),
|
||||
stds=(1., 1., 1., 1.),
|
||||
|
|
|
@ -8,11 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
func_name='mmdet.models.task_modules.coders.distance_point_bbox_coder'
|
||||
'.DistancePointBBoxCoder.decode',
|
||||
backend='default')
|
||||
def distancepointbboxcoder__decode(ctx,
|
||||
self,
|
||||
points,
|
||||
pred_bboxes,
|
||||
max_shape=None):
|
||||
def distancepointbboxcoder__decode(self, points, pred_bboxes, max_shape=None):
|
||||
"""Rewrite `mmdet.models.task_modules.coders.distance_point_bbox_coder. \
|
||||
DistancePointBBoxCoder.decode`
|
||||
|
||||
|
|
|
@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.task_modules.coders.tblr_bbox_coder.tblr2bboxes',
|
||||
backend='default')
|
||||
def tblr2bboxes(ctx,
|
||||
priors,
|
||||
def tblr2bboxes(priors,
|
||||
tblr,
|
||||
normalizer=4.0,
|
||||
normalize_by_wh=True,
|
||||
|
|
|
@ -77,7 +77,6 @@ grid_priors_trt = GridPriorsTRTOp.apply
|
|||
'AnchorGenerator.single_level_grid_priors',
|
||||
backend='tensorrt')
|
||||
def anchorgenerator__single_level_grid_priors__trt(
|
||||
ctx,
|
||||
self,
|
||||
featmap_size: Tuple[int],
|
||||
level_idx: int,
|
||||
|
@ -98,6 +97,7 @@ def anchorgenerator__single_level_grid_priors__trt(
|
|||
Returns:
|
||||
torch.Tensor: Anchors in the overall feature maps.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
from mmdet.models.task_modules.prior_generators import AnchorGenerator
|
||||
if type(self) != AnchorGenerator:
|
||||
# only use custom node on default generator.
|
||||
|
|
|
@ -10,7 +10,6 @@ from mmdeploy.utils.constants import Backend
|
|||
'.single_level_grid_priors',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def mlvl_point_generator__single_level_grid_priors__tensorrt(
|
||||
ctx,
|
||||
self,
|
||||
featmap_size,
|
||||
level_idx,
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.utils.transformer.PatchMerging.forward',
|
||||
backend='tensorrt')
|
||||
def patch_merging__forward__tensorrt(ctx, self, x, input_size):
|
||||
def patch_merging__forward__tensorrt(self, x, input_size):
|
||||
"""Rewrite forward function of PatchMerging class for TensorRT. In original
|
||||
implementation, mmdet applies nn.unfold to accelerate the inference.
|
||||
However, the onnx graph of it can not be parsed correctly by TensorRT. In
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.structures.bbox.transforms.distance2bbox' # noqa
|
||||
)
|
||||
def distance2bbox__default(ctx, points, distance, max_shape=None):
|
||||
def distance2bbox__default(points, distance, max_shape=None):
|
||||
"""Rewrite `mmdet.core.bbox.transforms.distance2bbox`
|
||||
|
||||
Decode distance prediction to bounding box.
|
||||
|
|
|
@ -9,8 +9,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501
|
||||
)
|
||||
def basedetector__forward(ctx,
|
||||
self,
|
||||
def basedetector__forward(self,
|
||||
inputs: list,
|
||||
data_samples=None,
|
||||
**kwargs) -> Tuple[List[torch.Tensor]]:
|
||||
|
|
|
@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_img_feat' # noqa: E501
|
||||
)
|
||||
def mvxtwostagedetector__extract_img_feat(ctx, self,
|
||||
img: torch.Tensor) -> dict:
|
||||
def mvxtwostagedetector__extract_img_feat(self, img: torch.Tensor) -> dict:
|
||||
"""Extract features of images."""
|
||||
if self.with_img_backbone and img is not None:
|
||||
if img.dim() == 5 and img.size(0) == 1:
|
||||
|
@ -26,8 +25,7 @@ def mvxtwostagedetector__extract_img_feat(ctx, self,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_feat')
|
||||
def mvxtwostagedetector__extract_feat(ctx, self,
|
||||
batch_inputs_dict: dict) -> tuple:
|
||||
def mvxtwostagedetector__extract_feat(self, batch_inputs_dict: dict) -> tuple:
|
||||
"""Rewrite this func to remove voxelize op.
|
||||
|
||||
Args:
|
||||
|
@ -47,7 +45,7 @@ def mvxtwostagedetector__extract_feat(ctx, self,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.forward')
|
||||
def mvxtwostagedetector__forward(ctx, self, inputs: list, **kwargs):
|
||||
def mvxtwostagedetector__forward(self, inputs: list, **kwargs):
|
||||
"""Rewrite this func to remove voxelize op.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.voxel_encoders.pillar_encoder.PillarFeatureNet.forward')
|
||||
def pillar_encoder__forward(ctx, self, features, num_points, coors, *args,
|
||||
def pillar_encoder__forward(self, features, num_points, coors, *args,
|
||||
**kwargs):
|
||||
"""Rewrite this func to optimize node. Modify the code at
|
||||
_with_voxel_center and use slice instead of the original operation.
|
||||
|
|
|
@ -7,11 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdet3d.models.middle_encoders.pillar_scatter.'
|
||||
'PointPillarsScatter.forward_batch', )
|
||||
def pointpillarsscatter__forward(ctx,
|
||||
self,
|
||||
voxel_features,
|
||||
coors,
|
||||
batch_size=1):
|
||||
def pointpillarsscatter__forward(self, voxel_features, coors, batch_size=1):
|
||||
"""Scatter features of single sample.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmedit.models.base_models.BaseEditModel.forward', backend='default')
|
||||
def base_edit_model__forward(
|
||||
ctx,
|
||||
self,
|
||||
batch_inputs: Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textdet.FPNC.forward', backend='tensorrt')
|
||||
def fpnc__forward__tensorrt(ctx, self, inputs, **kwargs):
|
||||
def fpnc__forward__tensorrt(self, inputs, **kwargs):
|
||||
"""Rewrite `forward` of FPNC for tensorrt backend.
|
||||
|
||||
Rewrite this function to replace nearest upsampling with bilinear
|
||||
|
|
|
@ -10,7 +10,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textdet.heads.BaseTextDetHead.predict')
|
||||
def base_text_det_head__predict(
|
||||
ctx, self, x: torch.Tensor,
|
||||
self, x: torch.Tensor,
|
||||
batch_data_samples: DetSampleList) -> DetSampleList:
|
||||
"""Rewrite `predict` of BaseTextDetHead for default backend.
|
||||
|
||||
|
@ -38,7 +38,7 @@ def base_text_det_head__predict(
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textdet.heads.DBHead.predict')
|
||||
def db_head__predict(ctx, self, x: torch.Tensor,
|
||||
def db_head__predict(self, x: torch.Tensor,
|
||||
batch_data_samples: DetSampleList) -> DetSampleList:
|
||||
"""Rewrite to avoid post-process of text detection head.
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textdet.SingleStageTextDetector.forward')
|
||||
def single_stage_text_detector__forward(
|
||||
ctx,
|
||||
self,
|
||||
batch_inputs: torch.Tensor,
|
||||
data_samples: TextDetDataSample = None,
|
||||
|
|
|
@ -10,7 +10,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textrecog.decoders.BaseDecoder.predict')
|
||||
def base_decoder__forward(
|
||||
ctx,
|
||||
self,
|
||||
feat: Optional[torch.Tensor] = None,
|
||||
out_enc: Optional[torch.Tensor] = None,
|
||||
|
|
|
@ -5,7 +5,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textrecog.decoders.CRNNDecoder.forward_train',
|
||||
backend='ncnn')
|
||||
def crnndecoder__forward_train__ncnn(ctx, self, feat, *args, **kwargs):
|
||||
def crnndecoder__forward_train__ncnn(self, feat, *args, **kwargs):
|
||||
"""Rewrite `forward_train` of CRNNDecoder for ncnn backend.
|
||||
|
||||
Rewrite this function to skip permuting dims of outputs from `[W, N, C]` to
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmocr.models.textrecog.EncoderDecoderRecognizer.forward')
|
||||
def encoder_decoder_recognizer__forward(ctx, self, batch_inputs: torch.Tensor,
|
||||
def encoder_decoder_recognizer__forward(self, batch_inputs: torch.Tensor,
|
||||
data_samples: TextRecogDataSample,
|
||||
**kwargs) -> TextRecogDataSample:
|
||||
"""Rewrite `forward` of EncoderDecoderRecognizer for default backend.
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
func_name='mmocr.models.textrecog.layers.lstm_layer'
|
||||
'.BidirectionalLSTM.forward',
|
||||
backend='ncnn')
|
||||
def bidirectionallstm__forward__ncnn(ctx, self, input):
|
||||
def bidirectionallstm__forward__ncnn(self, input):
|
||||
"""Rewrite `forward` of BidirectionalLSTM for ncnn backend.
|
||||
|
||||
Rewrite this function to set batch_first of rnn layer to true. RNN in ncnn
|
||||
|
|
|
@ -15,7 +15,6 @@ from mmdeploy.core import FUNCTION_REWRITER, MODULE_REWRITER
|
|||
'._2d_attention',
|
||||
backend='default')
|
||||
def parallel_sar_decoder__2d_attention(
|
||||
ctx,
|
||||
self,
|
||||
decoder_input: torch.Tensor,
|
||||
feat: torch.Tensor,
|
||||
|
@ -85,8 +84,7 @@ def parallel_sar_decoder__2d_attention(
|
|||
func_name='mmocr.models.textrecog.decoders.SequentialSARDecoder'
|
||||
'._2d_attention',
|
||||
backend='default')
|
||||
def sequential_sar_decoder__2d_attention(ctx,
|
||||
self,
|
||||
def sequential_sar_decoder__2d_attention(self,
|
||||
y_prev,
|
||||
feat,
|
||||
holistic_feat,
|
||||
|
@ -151,7 +149,6 @@ def sequential_sar_decoder__2d_attention(ctx,
|
|||
'.forward_test',
|
||||
backend='default')
|
||||
def sequential_sar_decoder__forward_test(
|
||||
ctx,
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
out_enc: torch.Tensor,
|
||||
|
|
|
@ -12,7 +12,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
func_name='mmocr.models.textrecog.encoders.SAREncoder.forward',
|
||||
backend='default')
|
||||
def sar_encoder__forward(
|
||||
ctx,
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
data_samples: Optional[Sequence[TextRecogDataSample]] = None):
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
'mmpose.models.heads.heatmap_heads.CPMHead.forward')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.heads.heatmap_heads.MSPNHead.forward')
|
||||
def mspn_head__forward(ctx, self, feats):
|
||||
def mspn_head__forward(self, feats):
|
||||
"""Rewrite `forward` of MSPNHead and CPMHead for default backend.
|
||||
|
||||
1. return last stage heatmaps directly.
|
||||
|
@ -18,6 +18,7 @@ def mspn_head__forward(ctx, self, feats):
|
|||
Returns:
|
||||
output_heatmap (torch.Tensor): Output heatmaps.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
msmu_batch_heatmaps = ctx.origin_func(self, feats)
|
||||
batch_heatmaps = msmu_batch_heatmaps[-1]
|
||||
return batch_heatmaps
|
||||
|
|
|
@ -4,7 +4,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmpose.models.pose_estimators.base.BasePoseEstimator.forward')
|
||||
def base_pose_estimator__forward(ctx, self, inputs, *args, **kwargs):
|
||||
def base_pose_estimator__forward(self, inputs, *args, **kwargs):
|
||||
"""Rewrite `forward` of TopDown for default backend.'.
|
||||
|
||||
1.directly call _forward of subclass.
|
||||
|
|
|
@ -17,8 +17,7 @@ from mmdeploy.utils import is_dynamic_shape
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.dense_heads.OrientedRPNHead.predict_by_feat')
|
||||
def rpn_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def rpn_head__predict_by_feat(self,
|
||||
cls_scores: List[Tensor],
|
||||
bbox_preds: List[Tensor],
|
||||
score_factors: Optional[List[Tensor]] = None,
|
||||
|
@ -62,6 +61,7 @@ def rpn_head__predict_by_feat(ctx,
|
|||
tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes,
|
||||
batch_mlvl_scores, batch_mlvl_centerness
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
img_metas = batch_img_metas
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
|
|
|
@ -15,8 +15,7 @@ from mmdeploy.mmcv.ops import multiclass_nms
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.roi_heads.bbox_heads.GVBBoxHead.predict_by_feat')
|
||||
def gv_bbox_head__predict_by_feat(ctx,
|
||||
self,
|
||||
def gv_bbox_head__predict_by_feat(self,
|
||||
rois: Tuple[Tensor],
|
||||
cls_scores: Tuple[Tensor],
|
||||
bbox_preds: Tuple[Tensor],
|
||||
|
@ -60,6 +59,7 @@ def gv_bbox_head__predict_by_feat(ctx,
|
|||
assert rois.ndim == 3, 'Only support export two stage ' \
|
||||
'model to ONNX ' \
|
||||
'with batch dimension. '
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
img_shape = batch_img_metas[0]['img_shape']
|
||||
if self.custom_cls_channels:
|
||||
|
|
|
@ -11,8 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.roi_heads.gv_ratio_roi_head'
|
||||
'.GVRatioRoIHead.predict_bbox')
|
||||
def gv_ratio_roi_head__predict_bbox(ctx,
|
||||
self,
|
||||
def gv_ratio_roi_head__predict_bbox(self,
|
||||
x: Tuple[Tensor],
|
||||
batch_img_metas: List[dict],
|
||||
rpn_results_list: InstanceList,
|
||||
|
|
|
@ -66,8 +66,7 @@ class MultiLevelRotatedRoiAlign(Function):
|
|||
backend='tensorrt')
|
||||
@mark(
|
||||
'rotated_roi_extractor', inputs=['feats', 'rois'], outputs=['bbox_feats'])
|
||||
def rotated_single_roi_extractor__forward__tensorrt(ctx,
|
||||
self,
|
||||
def rotated_single_roi_extractor__forward__tensorrt(self,
|
||||
feats,
|
||||
rois,
|
||||
roi_scale_factor=None):
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.models.task_modules.coders.gliding_vertex_coder'
|
||||
'.GVFixCoder.decode')
|
||||
def gvfixcoder__decode(ctx, self, hboxes, fix_deltas):
|
||||
def gvfixcoder__decode(self, hboxes, fix_deltas):
|
||||
"""Rewriter for GVFixCoder decode, support more dimension input."""
|
||||
|
||||
assert hboxes.size(
|
||||
|
|
|
@ -21,7 +21,7 @@ def _dist_torch(point1, point2):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmrotate.structures.bbox.box_converters.qbox2rbox')
|
||||
def qbox2rbox__default(ctx, boxes: Tensor) -> Tensor:
|
||||
def qbox2rbox__default(boxes: Tensor) -> Tensor:
|
||||
"""Convert quadrilateral boxes to rotated boxes.
|
||||
|
||||
Implement with PyTorch.
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.decode_heads.ema_head.EMAModule.forward')
|
||||
def ema_module__forward(ctx, self, feats):
|
||||
def ema_module__forward(self, feats):
|
||||
"""Rewrite `forward` for default backend.
|
||||
|
||||
Replace torch.einsum with other operations.
|
||||
|
|
|
@ -7,8 +7,8 @@ from mmdeploy.utils import get_root_logger
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.decode_heads.point_head.PointHead.get_points_test',
|
||||
backend='tensorrt')
|
||||
def point_head__get_points_test__tensorrt(ctx, self, seg_logits,
|
||||
uncertainty_func, cfg):
|
||||
def point_head__get_points_test__tensorrt(self, seg_logits, uncertainty_func,
|
||||
cfg):
|
||||
"""Sample points for testing.
|
||||
|
||||
1. set `num_points` no greater than TENSORRT_MAX_TOPK for tensorrt backend
|
||||
|
@ -26,6 +26,7 @@ def point_head__get_points_test__tensorrt(ctx, self, seg_logits,
|
|||
2) that contains [0, 1] x [0, 1] normalized coordinates of the
|
||||
most uncertain points from the ``height x width`` grid .
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
from mmdeploy.utils.constants import TENSORRT_MAX_TOPK
|
||||
|
||||
if cfg.subdivision_num_points > TENSORRT_MAX_TOPK:
|
||||
|
|
|
@ -7,8 +7,7 @@ from mmdeploy.utils import is_dynamic_shape
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
|
||||
def base_segmentor__forward(ctx,
|
||||
self,
|
||||
def base_segmentor__forward(self,
|
||||
inputs,
|
||||
data_samples=None,
|
||||
mode='predict',
|
||||
|
@ -27,6 +26,7 @@ def base_segmentor__forward(ctx,
|
|||
Returns:
|
||||
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if data_samples is None:
|
||||
data_samples = [SegDataSample()]
|
||||
|
||||
|
|
|
@ -4,8 +4,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.CascadeEncoderDecoder.predict')
|
||||
def cascade_encoder_decoder__predict(ctx, self, inputs, data_samples,
|
||||
**kwargs):
|
||||
def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
||||
"""Rewrite `predict` for default backend.
|
||||
|
||||
1. only support mode=`whole` inference
|
||||
|
|
|
@ -5,7 +5,7 @@ from mmdeploy.utils.constants import Backend
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.EncoderDecoder.predict')
|
||||
def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs):
|
||||
def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
||||
"""Rewrite `predict` for default backend.
|
||||
|
||||
1. only support mode=`whole` inference
|
||||
|
@ -32,7 +32,7 @@ def encoder_decoder__predict(ctx, self, inputs, data_samples, **kwargs):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.EncoderDecoder.predict',
|
||||
backend=Backend.RKNN.value)
|
||||
def encoder_decoder__predict__rknn(ctx, self, inputs, data_samples, **kwargs):
|
||||
def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
|
||||
"""Rewrite `predict` for RKNN backend.
|
||||
|
||||
Early return to avoid argmax operator.
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmdeploy.utils import is_dynamic_shape
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.utils.UpConvBlock.forward')
|
||||
def up_conv_block__forward(ctx, self, skip, x):
|
||||
def up_conv_block__forward(self, skip, x):
|
||||
"""Rewrite `forward` for default backend.
|
||||
|
||||
To support dynamic shape for UNet backbone,
|
||||
|
@ -23,6 +23,7 @@ def up_conv_block__forward(ctx, self, skip, x):
|
|||
Returns:
|
||||
Tensor: Upsampled output feature map.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
# only valid when self.upsample is from build_upsample_layer
|
||||
|
|
|
@ -62,18 +62,20 @@ class Mark(torch.autograd.Function):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.core.optimizers.function_marker.Mark.symbolic')
|
||||
def mark_symbolic(rewriter, g, x, *args):
|
||||
def mark_symbolic(g, x, *args):
|
||||
"""Rewrite symbolic of mark op."""
|
||||
if cfg_apply_marks(rewriter.cfg):
|
||||
return rewriter.origin_func(g, x, *args)
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if cfg_apply_marks(ctx.cfg):
|
||||
return ctx.origin_func(g, x, *args)
|
||||
return x
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.core.optimizers.function_marker.Mark.forward')
|
||||
def forward_of_mark(rewriter, ctx, x, dtype, shape, func, func_id, type, name,
|
||||
id, attrs) -> torch.Tensor:
|
||||
def forward_of_mark(ctx, x, dtype, shape, func, func_id, type, name, id,
|
||||
attrs) -> torch.Tensor:
|
||||
"""Rewrite forward of mark op."""
|
||||
rewriter = FUNCTION_REWRITER.get_context()
|
||||
deploy_cfg = rewriter.cfg
|
||||
# save calib data
|
||||
apply_marks = cfg_apply_marks(deploy_cfg)
|
||||
|
@ -182,7 +184,7 @@ def mark_tensors(xs: Any, func: str, func_id: int, io_type: str, ctx: Any,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.core.optimizers.function_marker.mark_tensors', ir=IR.TORCHSCRIPT)
|
||||
def remove_mark__torchscript(ctx, xs: Any, *args, **kwargs):
|
||||
def remove_mark__torchscript(xs: Any, *args, **kwargs):
|
||||
"""Disable all marks for TorchScript backend.
|
||||
|
||||
As the Node `mark` is not able to be traced, we just return original input
|
||||
|
@ -216,12 +218,15 @@ def mark(func_name: Optional[str] = None,
|
|||
>>> from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
>>> @FUNCTION_REWRITER.register_rewriter(
|
||||
>>> func_name='mmdet.models.roi_heads.ConvFCBBoxHead.forward')
|
||||
>>> @mark(
|
||||
>>> 'bbox_head_forward',
|
||||
>>> inputs=['bbox_feats'],
|
||||
>>> outputs=['cls_score', 'bbox_pred'])
|
||||
>>> def forward_of_bbox_head(ctx, self, x):
|
||||
>>> return ctx.origin_func(self, x)
|
||||
>>> def forward_of_bbox_head(self, x):
|
||||
>>> ctx = FUNCTION_REWRITER.get_context()
|
||||
>>> @mark(
|
||||
>>> 'bbox_head_forward',
|
||||
>>> inputs=['bbox_feats'],
|
||||
>>> outputs=['cls_score', 'bbox_pred'])
|
||||
>>> def _impl():
|
||||
>>> return ctx.origin_func(self, x)
|
||||
>>> return _impl()
|
||||
"""
|
||||
MARK_FUNCTION_COUNT[func_name] = 0
|
||||
|
||||
|
|
|
@ -1,11 +1,25 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
|
||||
Tuple, Union)
|
||||
|
||||
from mmdeploy.utils import IR, Backend, get_root_logger
|
||||
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
||||
copy_function, get_frame_func, get_func_qualname,
|
||||
import_function)
|
||||
|
||||
try:
|
||||
try:
|
||||
# torch>=1.10.0
|
||||
from torch.fx._symbolic_trace import _wrapped_fns_to_patch
|
||||
except ImportError:
|
||||
# 1.10.0>torch>=1.8.0
|
||||
from torch.fx.symbolic_trace import _wrapped_fns_to_patch
|
||||
except ImportError:
|
||||
# torch<1.8.0
|
||||
_wrapped_fns_to_patch = []
|
||||
|
||||
|
||||
def _replace_all_obj(obj: Any,
|
||||
new_obj: Any,
|
||||
|
@ -92,6 +106,24 @@ def _del_func(path: str):
|
|||
continue
|
||||
|
||||
|
||||
def _fx_wrap_copied_fn(func: types.FunctionType,
|
||||
copied_func: types.FunctionType):
|
||||
"""If a function is wrapped by torch.fx.wrap, its copy also needs to be
|
||||
wrapped by torch.fx.wrap."""
|
||||
if not hasattr(func, '__globals__'):
|
||||
return
|
||||
|
||||
wrapped_fns_globals = [item[0] for item in _wrapped_fns_to_patch]
|
||||
wrapped_fns_names = [item[1] for item in _wrapped_fns_to_patch]
|
||||
|
||||
# check if wrapped by torch.fx.wrap
|
||||
if func.__globals__ in wrapped_fns_globals:
|
||||
idx = wrapped_fns_globals.index(func.__globals__)
|
||||
fn_name = wrapped_fns_names[idx]
|
||||
# a hacky way to wrap the func in copied func
|
||||
_wrapped_fns_to_patch.append((copied_func.__globals__, fn_name))
|
||||
|
||||
|
||||
class FunctionRewriter:
|
||||
"""A function rewriter which maintains rewritten functions.
|
||||
|
||||
|
@ -102,7 +134,8 @@ class FunctionRewriter:
|
|||
Examples:
|
||||
>>> @FUNCTION_REWRITER.register_rewriter(
|
||||
>>> func_name='torch.Tensor.size', backend='ncnn')
|
||||
>>> def size_of_tensor_static(ctx, self, *args):
|
||||
>>> def size_of_tensor_static(self, *args):
|
||||
>>> ctx = FUNCTION_REWRITER.get_context()
|
||||
>>> ret = ctx.origin_func(self, *args)
|
||||
>>> if isinstance(ret, torch.Tensor):
|
||||
>>> ret = int(ret)
|
||||
|
@ -114,6 +147,7 @@ class FunctionRewriter:
|
|||
|
||||
def __init__(self):
|
||||
self._registry = RewriterRegistry()
|
||||
self._func_contexts = defaultdict(list)
|
||||
|
||||
def register_rewriter(
|
||||
self,
|
||||
|
@ -140,8 +174,11 @@ class FunctionRewriter:
|
|||
|
||||
def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs):
|
||||
"""The implementation of function rewrite."""
|
||||
self._func_contexts.clear()
|
||||
# Get current records
|
||||
functions_records = self._registry.get_records(env)
|
||||
# Get current fx wrapped func nums
|
||||
self._ori_fx_wrap_num = len(_wrapped_fns_to_patch)
|
||||
|
||||
self._origin_functions = list()
|
||||
self._additional_functions = list()
|
||||
|
@ -181,15 +218,25 @@ class FunctionRewriter:
|
|||
|
||||
# Create context_caller
|
||||
rewrite_function = record_dict['_object']
|
||||
# The func before and after copy has different globals
|
||||
rewrite_function = copy_function(rewrite_function)
|
||||
extra_kwargs = kwargs.copy()
|
||||
extra_kwargs.update(record_dict)
|
||||
context_caller = ContextCaller(
|
||||
rewrite_function, origin_func, cfg,
|
||||
**extra_kwargs).get_wrapped_caller()
|
||||
context_caller = ContextCaller(rewrite_function, origin_func,
|
||||
cfg, **extra_kwargs)
|
||||
# If there is a function wrapped by torch.fx.wrap in
|
||||
# rewrite_function's globals, we need to wrap the same name
|
||||
# function in copied function's globals.
|
||||
_fx_wrap_copied_fn(record_dict['_object'], context_caller.func)
|
||||
|
||||
qualname = get_func_qualname(rewrite_function)
|
||||
self._func_contexts[qualname].append(context_caller)
|
||||
self._func_contexts[function_path].append(context_caller)
|
||||
|
||||
# Cache new the function to avoid homonymic bug
|
||||
new_functions.append(
|
||||
dict(func_path=function_path, origin_func=context_caller))
|
||||
dict(
|
||||
func_path=function_path, origin_func=rewrite_function))
|
||||
|
||||
for func_dict in new_functions:
|
||||
function_path = func_dict['func_path']
|
||||
|
@ -199,9 +246,46 @@ class FunctionRewriter:
|
|||
|
||||
def exit(self):
|
||||
"""Recover the function rewrite."""
|
||||
# Restore _wrapped_fns_to_patch
|
||||
cur_fx_wrap_num = len(_wrapped_fns_to_patch)
|
||||
for _ in range(cur_fx_wrap_num - self._ori_fx_wrap_num):
|
||||
_wrapped_fns_to_patch.pop(-1)
|
||||
|
||||
for func_dict in self._origin_functions:
|
||||
func_path = func_dict['func_path']
|
||||
func = func_dict['origin_func']
|
||||
_set_func(func_path, func)
|
||||
for func_path in self._additional_functions:
|
||||
_del_func(func_path)
|
||||
|
||||
self._func_contexts.clear()
|
||||
|
||||
def get_context(self, key: Optional[str] = None) -> ContextCaller:
|
||||
"""Get the context of rewriter.
|
||||
|
||||
Args:
|
||||
key: key to the context.
|
||||
|
||||
Returns:
|
||||
ContextCaller: context of function
|
||||
"""
|
||||
func = None
|
||||
if key is None:
|
||||
func = get_frame_func(2)
|
||||
key = get_func_qualname(func)
|
||||
|
||||
# get all contexts
|
||||
ctxs = self._func_contexts.get(key, [])
|
||||
|
||||
if func is None:
|
||||
assert len(ctxs) == 1
|
||||
return ctxs[0]
|
||||
|
||||
ctx = None
|
||||
for tmp_ctx in ctxs:
|
||||
if tmp_ctx.func == func:
|
||||
ctx = tmp_ctx
|
||||
|
||||
if ctx is None:
|
||||
get_root_logger().warning(f'Can not found context of {key}')
|
||||
return ctx
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
import warnings
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from functools import wraps
|
||||
|
@ -342,8 +344,9 @@ class ContextCaller:
|
|||
|
||||
Example:
|
||||
>>> @FUNCTION_REWRITER.register_rewriter(func_name='torch.add')
|
||||
>>> def func(ctx, x, y):
|
||||
>>> def func(x, y):
|
||||
>>> # ctx is an instance of ContextCaller
|
||||
>>> ctx = FUNCTION_REWRITER.get_context()
|
||||
>>> print(ctx.cfg)
|
||||
>>> return x + y
|
||||
"""
|
||||
|
@ -379,3 +382,61 @@ class ContextCaller:
|
|||
return self.func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_func_qualname(func: Callable) -> str:
|
||||
"""get function name."""
|
||||
assert isinstance(func, Callable), f'{func} is not a Callable object.'
|
||||
_func_name = None
|
||||
if hasattr(func, '__qualname__'):
|
||||
_func_name = f'{func.__module__}.{func.__qualname__}'
|
||||
elif hasattr(func, '__class__'):
|
||||
_func_name = func.__class__
|
||||
else:
|
||||
_func_name = str(func)
|
||||
return _func_name
|
||||
|
||||
|
||||
def get_frame_func(top: int = 1) -> Callable:
|
||||
"""get func of frame."""
|
||||
frameinfo = inspect.stack()[top]
|
||||
frame = frameinfo.frame
|
||||
|
||||
g_vars = frame.f_globals
|
||||
func_name = frameinfo.function
|
||||
assert func_name in g_vars, \
|
||||
f'Can not find function: {func_name} in global.'
|
||||
func = g_vars[func_name]
|
||||
return func
|
||||
|
||||
|
||||
def get_frame_qualname(top: int = 1) -> str:
|
||||
"""get frame name."""
|
||||
frameinfo = inspect.stack()[top]
|
||||
frame = frameinfo.frame
|
||||
|
||||
g_vars = frame.f_globals
|
||||
func_name = frameinfo.function
|
||||
assert func_name in g_vars, \
|
||||
f'Can not find function: {func_name} in global.'
|
||||
func = g_vars[func_name]
|
||||
module_name = inspect.getmodule(func).__name__
|
||||
|
||||
return f'{module_name}.{func_name}'
|
||||
|
||||
|
||||
def copy_function(f: types.FunctionType):
|
||||
"""Copy the function."""
|
||||
# copy the global so we can get different func for different origin
|
||||
glb = f.__globals__.copy()
|
||||
name = f.__name__
|
||||
g = types.FunctionType(
|
||||
f.__code__,
|
||||
glb,
|
||||
name=name,
|
||||
argdefs=f.__defaults__,
|
||||
closure=f.__closure__)
|
||||
g = functools.update_wrapper(g, f)
|
||||
g.__kwdefaults__ = f.__kwdefaults__
|
||||
glb[name] = g
|
||||
return g
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
@ -7,7 +8,8 @@ from torch.onnx.symbolic_helper import parse_args
|
|||
|
||||
from mmdeploy.utils import IR, Backend, get_root_logger
|
||||
from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry,
|
||||
eval_with_import)
|
||||
copy_function, eval_with_import, get_frame_func,
|
||||
get_func_qualname)
|
||||
|
||||
|
||||
class SymbolicRewriter:
|
||||
|
@ -21,7 +23,7 @@ class SymbolicRewriter:
|
|||
Examples:
|
||||
>>> @SYMBOLIC_REWRITER.register_symbolic('squeeze', \
|
||||
>>> is_pytorch=True)
|
||||
>>> def squeeze_default(ctx, g, self, dim=None):
|
||||
>>> def squeeze_default(g, self, dim=None):
|
||||
>>> if dim is None:
|
||||
>>> dims = []
|
||||
>>> for i, size in enumerate(self.type().sizes()):
|
||||
|
@ -34,6 +36,7 @@ class SymbolicRewriter:
|
|||
|
||||
def __init__(self) -> None:
|
||||
self._registry = RewriterRegistry()
|
||||
self._func_contexts = defaultdict(list)
|
||||
|
||||
def register_symbolic(self,
|
||||
func_name: str,
|
||||
|
@ -75,6 +78,9 @@ class SymbolicRewriter:
|
|||
opset: int = 11,
|
||||
**kwargs):
|
||||
"""The implementation of symbolic register."""
|
||||
# clear context
|
||||
self._func_contexts.clear()
|
||||
|
||||
# Get current records
|
||||
symbolic_records = self._registry.get_records(env)
|
||||
|
||||
|
@ -84,19 +90,27 @@ class SymbolicRewriter:
|
|||
for function_name, record_dict in symbolic_records:
|
||||
|
||||
symbolic_function = record_dict['_object']
|
||||
symbolic_function = copy_function(symbolic_function)
|
||||
arg_descriptors = record_dict['arg_descriptors']
|
||||
extra_kwargs = kwargs.copy()
|
||||
extra_kwargs.update(record_dict)
|
||||
context_caller = ContextCaller(symbolic_function, None, cfg,
|
||||
**extra_kwargs)
|
||||
|
||||
# register context
|
||||
qualname = get_func_qualname(symbolic_function)
|
||||
self._func_contexts[qualname].append(context_caller)
|
||||
self._func_contexts[function_name].append(context_caller)
|
||||
|
||||
if arg_descriptors is not None and len(arg_descriptors) > 0:
|
||||
context_caller = parse_args(*arg_descriptors)(context_caller)
|
||||
symbolic_function = parse_args(*arg_descriptors)(
|
||||
symbolic_function)
|
||||
|
||||
is_pytorch = record_dict['is_pytorch']
|
||||
if is_pytorch:
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
register_custom_op_symbolic(f'::{function_name}',
|
||||
context_caller, opset)
|
||||
symbolic_function, opset)
|
||||
|
||||
# Save domain and version
|
||||
self._pytorch_symbolic.append((function_name, '', opset))
|
||||
|
@ -123,7 +137,7 @@ class SymbolicRewriter:
|
|||
self._extra_symbolic.append((origin_func, origin_symbolic))
|
||||
|
||||
# Cache new the function to avoid homonymic bug
|
||||
new_functions.append((origin_func, context_caller))
|
||||
new_functions.append((origin_func, symbolic_function))
|
||||
|
||||
for origin_func, new_func in new_functions:
|
||||
origin_symbolic = getattr(origin_func, 'symbolic', None)
|
||||
|
@ -132,6 +146,9 @@ class SymbolicRewriter:
|
|||
|
||||
def exit(self):
|
||||
"""The implementation of symbolic unregister."""
|
||||
# clear context
|
||||
self._func_contexts.clear()
|
||||
|
||||
# Unregister pytorch op
|
||||
if hasattr(torch.onnx, 'unregister_custom_op_symbolic'):
|
||||
from torch.onnx import unregister_custom_op_symbolic
|
||||
|
@ -149,3 +166,33 @@ class SymbolicRewriter:
|
|||
# Unregister custom op
|
||||
for origin_func, origin_symbolic in self._extra_symbolic:
|
||||
origin_func.symbolic = origin_symbolic
|
||||
|
||||
def get_context(self, key: Optional[str] = None) -> ContextCaller:
|
||||
"""Get the context of rewriter.
|
||||
|
||||
Args:
|
||||
key: key to the context.
|
||||
|
||||
Returns:
|
||||
ContextCaller: context of function
|
||||
"""
|
||||
func = None
|
||||
if key is None:
|
||||
func = get_frame_func(2)
|
||||
key = get_func_qualname(func)
|
||||
|
||||
# get all contexts
|
||||
ctxs = self._func_contexts.get(key, [])
|
||||
|
||||
if func is None:
|
||||
assert len(ctxs) == 1
|
||||
return ctxs[0]
|
||||
|
||||
ctx = None
|
||||
for tmp_ctx in ctxs:
|
||||
if tmp_ctx.func == func:
|
||||
ctx = tmp_ctx
|
||||
|
||||
if ctx is None:
|
||||
get_root_logger().warning(f'Can not found context of {key}')
|
||||
return ctx
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from . import conv2d_adaptive_padding # noqa: F401,F403
|
||||
from .transformer import MultiHeadAttentionop
|
||||
|
||||
__all__ = ['MultiHeadAttentionop']
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
|
||||
|
||||
|
||||
def compute_padding(input_size, kernel_size, stride, dilation):
|
||||
"""Compute padding."""
|
||||
|
||||
input_h, input_w = input_size
|
||||
kernel_h, kernel_w = kernel_size
|
||||
stride_h, stride_w = stride
|
||||
dilation_h, dilation_w = dilation
|
||||
output_h = math.ceil(input_h / stride_h)
|
||||
output_w = math.ceil(input_w / stride_w)
|
||||
pad_h = max(
|
||||
(output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h,
|
||||
0)
|
||||
pad_w = max(
|
||||
(output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w,
|
||||
0)
|
||||
if pad_w > 0 or pad_h > 0:
|
||||
padded = [
|
||||
pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
|
||||
]
|
||||
else:
|
||||
padded = None
|
||||
return padded
|
||||
|
||||
|
||||
class AdaptivePadOp(torch.autograd.Function):
|
||||
"""Dummy adaptive pad op."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, padded):
|
||||
if padded is not None:
|
||||
x = F.pad(x, padded)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, x, padded):
|
||||
if padded is None:
|
||||
return g.op('Identity', x)
|
||||
padded = g.op(
|
||||
'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
|
||||
constant_value = g.op(
|
||||
'Constant', value_t=torch.tensor(0, dtype=torch.int64))
|
||||
return g.op(
|
||||
'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. \
|
||||
Conv2dAdaptivePadding.forward',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
|
||||
"""Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for
|
||||
TensorRT backend. Main changes of this rewritten function is to separate
|
||||
the computation of padding and encapsulate it into another
|
||||
`torch.autograd.Function` so that the adaptive padding could be parsed as
|
||||
`Pad` ops in ONNX with the padding information computed in advance (Only
|
||||
for static shape configuration).
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor of Conv2dAdaptivePadding ops
|
||||
Returns:
|
||||
Tensor: forward result of 2D convolution after padding
|
||||
"""
|
||||
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
|
||||
padded = compute_padding(x.shape[2:], self.weight.shape[2:],
|
||||
self.stride, self.dilation)
|
||||
if padded is not None:
|
||||
padded = [int(_) for _ in padded]
|
||||
x = AdaptivePadOp.apply(x, padded)
|
||||
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups)
|
||||
else:
|
||||
x = ctx.origin_func(x)
|
||||
return x
|
|
@ -57,8 +57,7 @@ class MultiHeadAttentionop(torch.autograd.Function):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcv.cnn.bricks.transformer.MultiheadAttention.forward',
|
||||
backend=Backend.NCNN.value)
|
||||
def multiheadattention__forward__ncnn(ctx,
|
||||
self,
|
||||
def multiheadattention__forward__ncnn(self,
|
||||
query,
|
||||
key=None,
|
||||
value=None,
|
||||
|
|
|
@ -4,8 +4,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER
|
|||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'mmcv.ops.deform_conv.DeformConv2dFunction')
|
||||
def deform_conv__default(ctx,
|
||||
g,
|
||||
def deform_conv__default(g,
|
||||
input,
|
||||
offset,
|
||||
weight,
|
||||
|
@ -31,8 +30,7 @@ def deform_conv__default(ctx,
|
|||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino')
|
||||
def deform_conv_openvino(ctx,
|
||||
g,
|
||||
def deform_conv_openvino(g,
|
||||
input,
|
||||
offset,
|
||||
weight,
|
||||
|
|
|
@ -4,9 +4,8 @@ from mmdeploy.core import SYMBOLIC_REWRITER
|
|||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction')
|
||||
def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias,
|
||||
stride, padding, dilation, groups,
|
||||
deform_groups):
|
||||
def modulated_deform_conv_default(g, input, offset, mask, weight, bias, stride,
|
||||
padding, dilation, groups, deform_groups):
|
||||
"""Rewrite mdcn symbolic function for all backend."""
|
||||
input_tensors = [input, offset, mask, weight]
|
||||
if bias is not None:
|
||||
|
|
|
@ -82,7 +82,6 @@ class ONNXNMSop(torch.autograd.Function):
|
|||
Returns:
|
||||
NonMaxSuppression op for onnx.
|
||||
"""
|
||||
|
||||
if not sym_help._is_value(max_output_boxes_per_class):
|
||||
max_output_boxes_per_class = g.op(
|
||||
'Constant',
|
||||
|
@ -354,8 +353,7 @@ def _multiclass_nms_single(boxes: Tensor,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms')
|
||||
def multiclass_nms__default(ctx,
|
||||
boxes: Tensor,
|
||||
def multiclass_nms__default(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
|
@ -388,6 +386,7 @@ def multiclass_nms__default(ctx,
|
|||
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
|
||||
and `labels` of shape [N, num_det].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
deploy_cfg = ctx.cfg
|
||||
batch_size = boxes.size(0)
|
||||
if not is_dynamic_batch(deploy_cfg) and batch_size == 1:
|
||||
|
@ -414,8 +413,7 @@ def multiclass_nms__default(ctx,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.mmcv.ops.nms._multiclass_nms', backend='tensorrt')
|
||||
def multiclass_nms_static(ctx,
|
||||
boxes: Tensor,
|
||||
def multiclass_nms_static(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
|
@ -479,12 +477,35 @@ def multiclass_nms_static(ctx,
|
|||
'multiclass_nms',
|
||||
inputs=['boxes', 'scores'],
|
||||
outputs=['dets', 'labels', 'index'])
|
||||
def multiclass_nms(*args, nms_type='nms', **kwargs):
|
||||
def multiclass_nms(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
score_threshold: float = 0.05,
|
||||
pre_top_k: int = -1,
|
||||
keep_top_k: int = -1,
|
||||
output_index: bool = False,
|
||||
nms_type='nms'):
|
||||
"""Apis for multiclass nms."""
|
||||
if nms_type == 'nms':
|
||||
return _multiclass_nms(*args, **kwargs)
|
||||
return _multiclass_nms(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k,
|
||||
output_index=output_index)
|
||||
elif nms_type == 'nms_rotated':
|
||||
return multiclass_nms_rotated(*args, **kwargs)
|
||||
return multiclass_nms_rotated(
|
||||
boxes,
|
||||
scores,
|
||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
||||
else:
|
||||
raise NotImplementedError(f'Unsupported nms type: {nms_type}.')
|
||||
|
||||
|
@ -492,8 +513,7 @@ def multiclass_nms(*args, nms_type='nms', **kwargs):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
|
||||
backend=Backend.COREML.value)
|
||||
def multiclass_nms__coreml(ctx,
|
||||
boxes: Tensor,
|
||||
def multiclass_nms__coreml(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
|
@ -556,8 +576,7 @@ def multiclass_nms__coreml(ctx,
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
|
||||
ir=IR.TORCHSCRIPT)
|
||||
def multiclass_nms__torchscript(ctx,
|
||||
boxes: Tensor,
|
||||
def multiclass_nms__torchscript(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
|
@ -659,8 +678,7 @@ class AscendBatchNMSOp(torch.autograd.Function):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.mmcv.ops.nms.bbox_nms._multiclass_nms',
|
||||
backend='ascend')
|
||||
def multiclass_nms__ascend(ctx,
|
||||
boxes: Tensor,
|
||||
def multiclass_nms__ascend(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
|
|
|
@ -272,8 +272,7 @@ def _multiclass_nms_rotated(boxes: Tensor,
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdeploy.mmcv.ops.nms_rotated._multiclass_nms_rotated',
|
||||
backend='tensorrt')
|
||||
def multiclass_nms_rotated__tensorrt(ctx,
|
||||
boxes: Tensor,
|
||||
def multiclass_nms_rotated__tensorrt(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.5,
|
||||
|
@ -317,7 +316,19 @@ def multiclass_nms_rotated__tensorrt(ctx,
|
|||
'multiclass_nms_rotated',
|
||||
inputs=['boxes', 'scores'],
|
||||
outputs=['dets', 'labels'])
|
||||
def multiclass_nms_rotated(*args, **kwargs):
|
||||
def multiclass_nms_rotated(boxes: Tensor,
|
||||
scores: Tensor,
|
||||
max_output_boxes_per_class: int = 1000,
|
||||
iou_threshold: float = 0.1,
|
||||
score_threshold: float = 0.05,
|
||||
pre_top_k: int = -1,
|
||||
keep_top_k: int = -1):
|
||||
"""Wrapper function for `_multiclass_nms`."""
|
||||
return mmdeploy.mmcv.ops.nms_rotated._multiclass_nms_rotated(
|
||||
*args, **kwargs)
|
||||
boxes,
|
||||
scores,
|
||||
max_output_boxes_per_class=max_output_boxes_per_class,
|
||||
iou_threshold=iou_threshold,
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcv.ops.point_sample', backend='default')
|
||||
def point_sample__default(ctx, input, points, align_corners=False, **kwargs):
|
||||
def point_sample__default(input, points, align_corners=False, **kwargs):
|
||||
"""A wrapper around :func:`grid_sample` to support 3D point_coords tensors
|
||||
Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
|
||||
lie inside ``[0, 1] x [0, 1]`` square.
|
||||
|
@ -37,7 +37,7 @@ def point_sample__default(ctx, input, points, align_corners=False, **kwargs):
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcv.ops.SimpleRoIAlign.forward')
|
||||
def simple_roialign__forward(ctx, self, features, rois):
|
||||
def simple_roialign__forward(self, features, rois):
|
||||
"""Rewrite `forward` of SimpleRoIAlign.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -13,9 +13,9 @@ from mmdeploy.utils import Backend, get_backend, get_ir_config
|
|||
# visible in mmcv.
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'mmcv.ops.roi_align.__self__', backend='default')
|
||||
def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
|
||||
output_size: List[int], spatial_scale: float,
|
||||
sampling_ratio: int, pool_mode: str, aligned: bool):
|
||||
def roi_align_default(g, input: Tensor, rois: Tensor, output_size: List[int],
|
||||
spatial_scale: float, sampling_ratio: int,
|
||||
pool_mode: str, aligned: bool):
|
||||
"""Rewrite symbolic function for default backend.
|
||||
|
||||
Replace onnx::RoiAlign with mmcv::MMCVRoiAlign for PPLNN. For ONNXRuntime,
|
||||
|
@ -41,6 +41,7 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
|
|||
Returns:
|
||||
MMCVRoiAlign op for onnx.
|
||||
"""
|
||||
ctx = SYMBOLIC_REWRITER.get_context()
|
||||
backend = get_backend(ctx.cfg)
|
||||
if backend == Backend.PPLNN or backend == Backend.TENSORRT:
|
||||
domain = 'mmcv'
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmdeploy.core import SYMBOLIC_REWRITER
|
|||
# is not visible in mmcv.
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'mmcv.ops.roi_align_rotated.__self__', backend='default')
|
||||
def roi_align_rotated_default(ctx, g, input: Tensor, rois: Tensor,
|
||||
def roi_align_rotated_default(g, input: Tensor, rois: Tensor,
|
||||
output_size: List[int], spatial_scale: float,
|
||||
sampling_ratio: int, aligned: bool,
|
||||
clockwise: bool):
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmdeploy.utils import Backend
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmcv.cnn.bricks.transformer.PatchEmbed.forward',
|
||||
backend=Backend.NCNN.value)
|
||||
def patch_embed__forward__ncnn(ctx, self, x):
|
||||
def patch_embed__forward__ncnn(self, x):
|
||||
"""Rewrite `forward` of PatchEmbed for ncnn backend.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,8 +9,9 @@ from mmdeploy.utils import Backend, get_root_logger, is_dynamic_shape
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.adaptive_avg_pool2d')
|
||||
def adaptive_avg_pool2d__default(ctx, input, output_size):
|
||||
def adaptive_avg_pool2d__default(input, output_size):
|
||||
"""Rewrite `adaptive_avg_pool2d` for default backend."""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
output_size = _pair(output_size)
|
||||
if int(output_size[0]) == int(output_size[1]) == 1:
|
||||
out = ctx.origin_func(input, output_size)
|
||||
|
@ -39,6 +40,7 @@ def adaptive_avg_pool2d__default(ctx, input, output_size):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.adaptive_avg_pool2d',
|
||||
backend=Backend.TORCHSCRIPT.value)
|
||||
def adaptive_avg_pool2d__ncnn(ctx, input, output_size):
|
||||
def adaptive_avg_pool2d__ncnn(input, output_size):
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
"""Rewrite `adaptive_avg_pool2d` for ncnn and torchscript backend."""
|
||||
return ctx.origin_func(input, output_size)
|
||||
|
|
|
@ -7,7 +7,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.atan2', backend='default')
|
||||
def atan2__default(
|
||||
ctx,
|
||||
input1: torch.Tensor,
|
||||
input2: torch.Tensor,
|
||||
):
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmdeploy.utils import IR
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.chunk', backend='ncnn')
|
||||
def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor:
|
||||
def chunk__ncnn(self, num_chunks: int, dim: int = 0) -> torch.Tensor:
|
||||
"""Rewrite `chunk` for NCNN backend.
|
||||
|
||||
Chunk in ncnn are not supported, so it should be rewritten.
|
||||
|
@ -36,10 +36,7 @@ def chunk__ncnn(ctx, self, num_chunks: int, dim: int = 0) -> torch.Tensor:
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.chunk', ir=IR.TORCHSCRIPT)
|
||||
def chunk__torchscript(ctx,
|
||||
self,
|
||||
num_chunks: int,
|
||||
dim: int = 0) -> torch.Tensor:
|
||||
def chunk__torchscript(self, num_chunks: int, dim: int = 0) -> torch.Tensor:
|
||||
"""Rewrite `chunk` for Torchscript.
|
||||
|
||||
Replace chunk op with split op
|
||||
|
|
|
@ -13,11 +13,12 @@ from mmdeploy.utils import Backend
|
|||
func_name='torch.Tensor.clamp', backend=Backend.COREML.value)
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.clamp', backend=Backend.COREML.value)
|
||||
def clip__coreml(ctx, input, min=None, max=None, **kwargs) -> torch.Tensor:
|
||||
def clip__coreml(input, min=None, max=None, **kwargs) -> torch.Tensor:
|
||||
"""Rewrite `clip` for coreml backend.
|
||||
|
||||
Cast data type.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if min is not None and not isinstance(min, torch.Tensor):
|
||||
min = input.new_tensor(min)
|
||||
|
||||
|
|
|
@ -6,11 +6,12 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.expand', backend='ncnn')
|
||||
def expand__ncnn(ctx, self, *sizes) -> torch.Tensor:
|
||||
def expand__ncnn(self, *sizes) -> torch.Tensor:
|
||||
"""Rewrite `expand` for NCNN backend.
|
||||
|
||||
Do not expand on batch dim for tensor with ndim >= 3
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if self.ndim < 3 or sizes[0] not in [1, -1]:
|
||||
return ctx.origin_func(*sizes)
|
||||
return self
|
||||
|
|
|
@ -13,7 +13,7 @@ from mmdeploy.utils import Backend
|
|||
func_name='torch.Tensor.flatten', backend=Backend.COREML.value)
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.flatten', backend=Backend.COREML.value)
|
||||
def flatten__coreml(ctx, input, start_dim=0, end_dim=-1) -> torch.Tensor:
|
||||
def flatten__coreml(input, start_dim=0, end_dim=-1) -> torch.Tensor:
|
||||
"""Rewrite `flatten` for coreml backend.
|
||||
|
||||
Use reshape instead of flatten
|
||||
|
|
|
@ -6,13 +6,14 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.__getattribute__', backend='ncnn')
|
||||
def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str):
|
||||
def tensor__getattribute__ncnn(self: torch.Tensor, name: str):
|
||||
"""Rewrite `__getattribute__` of `torch.Tensor` for ncnn backend.
|
||||
|
||||
Shape node is not supported by ncnn. This function transform dynamic shape
|
||||
to constant shape.
|
||||
"""
|
||||
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
ret = ctx.origin_func(self, name)
|
||||
if name == 'shape':
|
||||
ret = torch.Size([int(s) for s in ret])
|
||||
|
|
|
@ -9,7 +9,6 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.group_norm', backend='ncnn')
|
||||
def group_norm__ncnn(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
num_groups: int,
|
||||
weight: Union[torch.Tensor, torch.NoneType] = None,
|
||||
|
|
|
@ -10,8 +10,7 @@ from mmdeploy.utils import Backend, get_root_logger
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.interpolate', backend='ncnn')
|
||||
def interpolate__ncnn(ctx,
|
||||
input: torch.Tensor,
|
||||
def interpolate__ncnn(input: torch.Tensor,
|
||||
size: Optional[Union[int, Tuple[int], Tuple[int, int],
|
||||
Tuple[int, int, int]]] = None,
|
||||
scale_factor: Optional[Union[float,
|
||||
|
@ -24,6 +23,7 @@ def interpolate__ncnn(ctx,
|
|||
ncnn require `size` should be constant in ONNX Node. We use `scale_factor`
|
||||
instead of `size` to avoid dynamic size.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
input_size = input.shape
|
||||
if scale_factor is None:
|
||||
|
@ -42,8 +42,7 @@ def interpolate__ncnn(ctx,
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.interpolate', backend='rknn')
|
||||
def interpolate__rknn(ctx,
|
||||
input: torch.Tensor,
|
||||
def interpolate__rknn(input: torch.Tensor,
|
||||
size: Optional[Union[int, Tuple[int], Tuple[int, int],
|
||||
Tuple[int, int, int]]] = None,
|
||||
scale_factor: Optional[Union[float,
|
||||
|
@ -56,6 +55,7 @@ def interpolate__rknn(ctx,
|
|||
rknn require `size` should be constant in ONNX Node. We use `scale_factor`
|
||||
instead of `size` to avoid dynamic size.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
input_size = input.shape
|
||||
if scale_factor is None:
|
||||
scale_factor = [(s_out / s_in)
|
||||
|
@ -77,7 +77,6 @@ def interpolate__rknn(ctx,
|
|||
is_pytorch=True,
|
||||
backend=Backend.TENSORRT.value)
|
||||
def interpolate__tensorrt(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
size: Optional[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int,
|
||||
int]]] = None,
|
||||
|
@ -87,6 +86,7 @@ def interpolate__tensorrt(
|
|||
recompute_scale_factor: Optional[bool] = None,
|
||||
):
|
||||
"""Register default symbolic function for `interpolate`."""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
|
||||
class BicubicInterpolate(Function):
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@ class GemmOp(torch.autograd.Function):
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.linear', backend='ncnn')
|
||||
def linear__ncnn(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[Union[torch.Tensor, torch.NoneType]] = None,
|
||||
|
@ -41,6 +40,7 @@ def linear__ncnn(
|
|||
add extra reshape and transpose to support linear operation of different
|
||||
input shape.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
origin_func = ctx.origin_func
|
||||
dim = input.dim()
|
||||
|
||||
|
|
|
@ -13,13 +13,14 @@ from mmdeploy.utils.constants import Backend
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.masked_fill', backend=Backend.ONNXRUNTIME.value)
|
||||
def masked_fill__onnxruntime(
|
||||
ctx, input, mask: torch.Tensor, value: Union[torch.Tensor,
|
||||
Number]) -> torch.Tensor:
|
||||
input, mask: torch.Tensor, value: Union[torch.Tensor,
|
||||
Number]) -> torch.Tensor:
|
||||
"""Rewrite `masked_fill` for onnxruntime backend.
|
||||
|
||||
SATRN model as example, when value is set to `float('-inf')`, the results
|
||||
of ORT inferencing turns out to be NAN.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if value == float('-inf'):
|
||||
value = -1e34 # hard coding number
|
||||
return ctx.origin_func(input, mask, value)
|
||||
|
|
|
@ -11,10 +11,11 @@ from mmdeploy.utils.constants import Backend
|
|||
# TODO add version control when MOD is supported by TensorRT
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.__mod__', backend=Backend.TENSORRT.value)
|
||||
def mod__tensorrt(ctx, input: torch.Tensor, other: Union[torch.Tensor,
|
||||
torch.NumberType],
|
||||
*args, **kwargs) -> torch.Tensor:
|
||||
def mod__tensorrt(input: torch.Tensor, other: Union[torch.Tensor,
|
||||
torch.NumberType], *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
"""Rewrite `mod` when exporting model to ONNX for TensorRT backend."""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if version.parse(torch.__version__) > version.parse('1.10.0'):
|
||||
return input - (input // other) * other
|
||||
return ctx.origin_func(input, other, *args, **kwargs)
|
||||
|
|
|
@ -46,7 +46,6 @@ class ScaledDotProductAttentionTRT(torch.autograd.Function):
|
|||
func_name='torch.nn.functional._scaled_dot_product_attention',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def _scaled_dot_product_attention__tensorrt(
|
||||
ctx,
|
||||
q: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
|
|
|
@ -7,8 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.normalize', backend='ncnn')
|
||||
def normalize__ncnn(ctx,
|
||||
input: torch.Tensor,
|
||||
def normalize__ncnn(input: torch.Tensor,
|
||||
p: int = 2,
|
||||
dim: int = 1,
|
||||
eps: float = 1e-12,
|
||||
|
@ -18,6 +17,7 @@ def normalize__ncnn(ctx,
|
|||
|
||||
Make sure L2 norm on channel dim and be exported to ncnn correctly.
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if dim < 0:
|
||||
dim += input.ndim
|
||||
assert dim != 0, 'Should not normalize on batch index'
|
||||
|
|
|
@ -11,7 +11,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
|||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.onnx.symbolic_opset11._prepare_onnx_paddings',
|
||||
backend='tensorrt')
|
||||
def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad):
|
||||
def _prepare_onnx_paddings__tensorrt(g, input, pad):
|
||||
"""Rewrite `_prepare_onnx_paddings` for TensorRT backend.
|
||||
|
||||
For codes like `x = torch.nn.ZeroPad2d((0, a, 0, b))(x)`, where a and b are
|
||||
|
@ -26,6 +26,7 @@ def _prepare_onnx_paddings__tensorrt(ctx, g, input, pad):
|
|||
..., dim_m_begin, dim_m_end,
|
||||
where m is in range [0, n].
|
||||
"""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
torch_version = version_parse(torch.__version__)
|
||||
if torch_version.major == 1 and torch_version.minor < 10:
|
||||
return ctx.origin_func(g, input, pad)
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue