mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement] Update function name and docstring in mmdeploy.pytorch (#191)
* Update function name and docstring in mmdeploy.pytorch * remove in docstring
This commit is contained in:
parent
a4dceb4bb4
commit
acf1dc5d88
@ -1,13 +1,13 @@
|
||||
from .getattribute import getattribute_static
|
||||
from .group_norm import group_norm_ncnn
|
||||
from .interpolate import interpolate_static
|
||||
from .linear import linear_ncnn
|
||||
from .repeat import repeat_static
|
||||
from .size import size_of_tensor_static
|
||||
from .topk import topk_dynamic, topk_static
|
||||
from .getattribute import tensor__getattribute__ncnn
|
||||
from .group_norm import group_norm__ncnn
|
||||
from .interpolate import interpolate__ncnn
|
||||
from .linear import linear__ncnn
|
||||
from .repeat import tensor__repeat__tensorrt
|
||||
from .size import tensor__size__ncnn
|
||||
from .topk import topk__dynamic, topk__tensorrt
|
||||
|
||||
__all__ = [
|
||||
'getattribute_static', 'group_norm_ncnn', 'interpolate_static',
|
||||
'linear_ncnn', 'repeat_static', 'size_of_tensor_static', 'topk_static',
|
||||
'topk_dynamic'
|
||||
'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
|
||||
'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn',
|
||||
'topk__dynamic', 'topk__tensorrt'
|
||||
]
|
||||
|
@ -5,8 +5,12 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.__getattribute__', backend='ncnn')
|
||||
def getattribute_static(ctx, self, name):
|
||||
"""Rewrite `__getattribute__` for NCNN backend."""
|
||||
def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str):
|
||||
"""Rewrite `__getattribute__` of `torch.Tensor` for NCNN backend.
|
||||
|
||||
Shape node is not supported by ncnn. This function transforms dynamic shape
|
||||
to constant shape.
|
||||
"""
|
||||
|
||||
ret = ctx.origin_func(self, name)
|
||||
if name == 'shape':
|
||||
|
@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.group_norm', backend='ncnn')
|
||||
def group_norm_ncnn(
|
||||
def group_norm__ncnn(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
num_groups: int,
|
||||
@ -15,7 +15,11 @@ def group_norm_ncnn(
|
||||
bias: Union[torch.Tensor, torch.NoneType] = None,
|
||||
eps: float = 1e-05,
|
||||
) -> torch.Tensor:
|
||||
"""Rewrite `group_norm` for NCNN backend."""
|
||||
"""Rewrite `group_norm` for NCNN backend.
|
||||
|
||||
InstanceNorm in ncnn requires input with shape [C, H, W]. So we have to
|
||||
reshape the input tensor before it.
|
||||
"""
|
||||
input_shape = input.shape
|
||||
batch_size = input_shape[0]
|
||||
# We cannot use input.reshape(batch_size, num_groups, -1, 1)
|
||||
|
@ -1,16 +1,26 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.interpolate', backend='ncnn')
|
||||
def interpolate_static(ctx,
|
||||
input,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode='nearest',
|
||||
align_corners=None,
|
||||
recompute_scale_factor=None):
|
||||
"""Rewrite `interpolate` for NCNN backend."""
|
||||
def interpolate__ncnn(ctx,
|
||||
input: torch.Tensor,
|
||||
size: Optional[Union[int, Tuple[int], Tuple[int, int],
|
||||
Tuple[int, int, int]]] = None,
|
||||
scale_factor: Optional[Union[float,
|
||||
Tuple[float]]] = None,
|
||||
mode: str = 'nearest',
|
||||
align_corners: Optional[bool] = None,
|
||||
recompute_scale_factor: Optional[bool] = None):
|
||||
"""Rewrite `interpolate` for NCNN backend.
|
||||
|
||||
NCNN requires `size` to be constant in the ONNX node. We use `scale_factor`
|
||||
instead of `size` to avoid dynamic size.
|
||||
"""
|
||||
|
||||
input_size = input.shape
|
||||
if scale_factor is None:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -7,13 +7,18 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.nn.functional.linear', backend='ncnn')
|
||||
def linear_ncnn(
|
||||
def linear__ncnn(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Union[torch.Tensor, torch.NoneType] = None,
|
||||
bias: Optional[Union[torch.Tensor, torch.NoneType]] = None,
|
||||
):
|
||||
"""Rewrite `linear` for NCNN backend."""
|
||||
"""Rewrite `linear` for NCNN backend.
|
||||
|
||||
The broadcast rules are different between ncnn and PyTorch. This function
|
||||
add extra reshape and transpose to support linear operation of different
|
||||
input shape.
|
||||
"""
|
||||
|
||||
origin_func = ctx.origin_func
|
||||
|
||||
|
@ -1,10 +1,19 @@
|
||||
from typing import Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.repeat', backend='tensorrt')
|
||||
def repeat_static(ctx, input, *size):
|
||||
"""Rewrite `repeat` for NCNN backend."""
|
||||
def tensor__repeat__tensorrt(ctx, input: torch.Tensor,
|
||||
*size: Union[torch.Size, Sequence[int]]):
|
||||
"""Rewrite `repeat` for TensorRT backend.
|
||||
|
||||
Some layers in TensorRT can not be applied on the batch axis. Add an extra axis
|
||||
before operation and remove it afterward.
|
||||
"""
|
||||
|
||||
origin_func = ctx.origin_func
|
||||
if input.dim() == 1 and len(size) == 1:
|
||||
|
@ -5,8 +5,12 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.size', backend='ncnn')
|
||||
def size_of_tensor_static(ctx, self, *args):
|
||||
"""Rewrite `size` for NCNN backend."""
|
||||
def tensor__size__ncnn(ctx, self, *args):
|
||||
"""Rewrite `size` for NCNN backend.
|
||||
|
||||
ONNX Shape node is not supported in ncnn. This function returns an integer
|
||||
instead of torch.Size to avoid the ONNX Shape node.
|
||||
"""
|
||||
|
||||
ret = ctx.origin_func(self, *args)
|
||||
if isinstance(ret, torch.Tensor):
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
@ -6,8 +8,16 @@ from mmdeploy.core import FUNCTION_REWRITER
|
||||
@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='default')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.topk', backend='default')
|
||||
def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
|
||||
"""Rewrite `topk` for default backend."""
|
||||
def topk__dynamic(ctx,
|
||||
input: torch.Tensor,
|
||||
k: int,
|
||||
dim: Optional[int] = None,
|
||||
largest: bool = True,
|
||||
sorted: bool = True):
|
||||
"""Rewrite `topk` for default backend.
|
||||
|
||||
Cast k to a tensor and make sure k is smaller than input.shape[dim].
|
||||
"""
|
||||
|
||||
if dim is None:
|
||||
dim = int(input.ndim - 1)
|
||||
@ -25,8 +35,17 @@ def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
|
||||
func_name='torch.topk', backend='tensorrt')
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='torch.Tensor.topk', backend='tensorrt')
|
||||
def topk_static(ctx, input, k, dim=None, largest=True, sorted=True):
|
||||
"""Rewrite `topk` for TensorRT backend."""
|
||||
def topk__tensorrt(ctx,
|
||||
input: torch.Tensor,
|
||||
k: int,
|
||||
dim: Optional[int] = None,
|
||||
largest: bool = True,
|
||||
sorted: bool = True):
|
||||
"""Rewrite `topk` for TensorRT backend.
|
||||
|
||||
TensorRT does not support topk with dynamic k. This function cast k to
|
||||
constant integer.
|
||||
"""
|
||||
|
||||
if dim is None:
|
||||
dim = int(input.ndim - 1)
|
||||
|
@ -1,11 +1,12 @@
|
||||
from .adaptive_avg_pool import (adaptive_avg_pool1d_op, adaptive_avg_pool2d_op,
|
||||
adaptive_avg_pool3d_op)
|
||||
from .grid_sampler import grid_sampler_default
|
||||
from .instance_norm import instance_norm_trt
|
||||
from .squeeze import squeeze_default
|
||||
from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
|
||||
adaptive_avg_pool2d__default,
|
||||
adaptive_avg_pool3d__default)
|
||||
from .grid_sampler import grid_sampler__default
|
||||
from .instance_norm import instance_norm__tensorrt
|
||||
from .squeeze import squeeze__default
|
||||
|
||||
__all__ = [
|
||||
'adaptive_avg_pool1d_op', 'adaptive_avg_pool2d_op',
|
||||
'adaptive_avg_pool3d_op', 'grid_sampler_default', 'instance_norm_trt',
|
||||
'squeeze_default'
|
||||
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
|
||||
'adaptive_avg_pool3d__default', 'grid_sampler__default',
|
||||
'instance_norm__tensorrt', 'squeeze__default'
|
||||
]
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# Modified from:
|
||||
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
|
||||
|
||||
@ -53,18 +54,27 @@ adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
|
||||
def adaptive_avg_pool1d_op(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool1d`."""
|
||||
def adaptive_avg_pool1d__default(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool1d`.
|
||||
|
||||
Align symbolic of adaptive_pool between different torch version.
|
||||
"""
|
||||
return adaptive_avg_pool1d(*args)
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
|
||||
def adaptive_avg_pool2d_op(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool2d`."""
|
||||
def adaptive_avg_pool2d__default(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool2d`.
|
||||
|
||||
Align symbolic of adaptive_pool between different torch version.
|
||||
"""
|
||||
return adaptive_avg_pool2d(*args)
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
|
||||
def adaptive_avg_pool3d_op(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool3d`."""
|
||||
def adaptive_avg_pool3d__default(ctx, *args):
|
||||
"""Register default symbolic function for `adaptive_avg_pool3d`.
|
||||
|
||||
Align symbolic of adaptive_pool between different torch version.
|
||||
"""
|
||||
return adaptive_avg_pool3d(*args)
|
||||
|
@ -10,7 +10,12 @@ def grid_sampler(g,
|
||||
interpolation_mode,
|
||||
padding_mode,
|
||||
align_corners=False):
|
||||
"""Symbolic function for `grid_sampler`."""
|
||||
"""Symbolic function for `grid_sampler`.
|
||||
|
||||
PyTorch does not support export grid_sampler to ONNX by default. We add the
|
||||
support here. `grid_sampler` will be exported as ONNX node
|
||||
'mmcv::grid_sampler'
|
||||
"""
|
||||
return g.op(
|
||||
'mmcv::grid_sampler',
|
||||
input,
|
||||
@ -21,6 +26,9 @@ def grid_sampler(g,
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
|
||||
def grid_sampler_default(ctx, *args):
|
||||
"""Register default symbolic function for `grid_sampler`."""
|
||||
def grid_sampler__default(ctx, *args):
|
||||
"""Register default symbolic function for `grid_sampler`.
|
||||
|
||||
Add support to grid_sample to ONNX.
|
||||
"""
|
||||
return grid_sampler(*args)
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
# Modified from:
|
||||
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
|
||||
|
||||
@ -63,7 +64,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'group_norm', backend='tensorrt', is_pytorch=True)
|
||||
def instance_norm_trt(ctx, *args):
|
||||
def instance_norm__tensorrt(ctx, *args):
|
||||
"""Register symbolic function for TensorRT backend.
|
||||
|
||||
Notes:
|
||||
|
@ -4,8 +4,12 @@ from mmdeploy.core import SYMBOLIC_REWRITER
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
|
||||
def squeeze_default(ctx, g, self, dim=None):
|
||||
"""Register default symbolic function for `squeeze`."""
|
||||
def squeeze__default(ctx, g, self, dim=None):
|
||||
"""Register default symbolic function for `squeeze`.
|
||||
|
||||
squeeze might be exported with IF node in ONNX, which is not supported in
|
||||
lots of backend.
|
||||
"""
|
||||
if dim is None:
|
||||
dims = []
|
||||
for i, size in enumerate(self.type().sizes()):
|
||||
|
Loading…
x
Reference in New Issue
Block a user