Mirror of https://github.com/open-mmlab/mmdeploy.git
Synced 2025-01-14 08:09:43 +08:00

[Enhancement] Update function name and docstring in mmdeploy.pytorch (#191)

* Update function name and docstring in mmdeploy.python
* remove in docstring

Parent: a4dceb4bb4
Commit: acf1dc5d88
mmdeploy/pytorch/functions/__init__.py:

@@ -1,13 +1,13 @@
-from .getattribute import getattribute_static
-from .group_norm import group_norm_ncnn
-from .interpolate import interpolate_static
-from .linear import linear_ncnn
-from .repeat import repeat_static
-from .size import size_of_tensor_static
-from .topk import topk_dynamic, topk_static
+from .getattribute import tensor__getattribute__ncnn
+from .group_norm import group_norm__ncnn
+from .interpolate import interpolate__ncnn
+from .linear import linear__ncnn
+from .repeat import tensor__repeat__tensorrt
+from .size import tensor__size__ncnn
+from .topk import topk__dynamic, topk__tensorrt

 __all__ = [
-    'getattribute_static', 'group_norm_ncnn', 'interpolate_static',
-    'linear_ncnn', 'repeat_static', 'size_of_tensor_static', 'topk_static',
-    'topk_dynamic'
+    'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
+    'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn',
+    'topk__dynamic', 'topk__tensorrt'
 ]
mmdeploy/pytorch/functions/getattribute.py:

@@ -5,8 +5,12 @@ from mmdeploy.core import FUNCTION_REWRITER

 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.Tensor.__getattribute__', backend='ncnn')
-def getattribute_static(ctx, self, name):
-    """Rewrite `__getattribute__` for NCNN backend."""
+def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str):
+    """Rewrite `__getattribute__` of `torch.Tensor` for NCNN backend.
+
+    The Shape node is not supported by ncnn. This function transforms a
+    dynamic shape into a constant shape.
+    """
     ret = ctx.origin_func(self, name)
     if name == 'shape':
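For reviewers unfamiliar with the trick the new docstring describes: here is a minimal standalone sketch in plain PyTorch, with no mmdeploy dependency. `fold_shape` is a hypothetical helper name, not the rewriter itself.

import torch

def fold_shape(self: torch.Tensor) -> torch.Size:
    # Converting every element of `shape` to a plain Python int detaches it
    # from tracing, so ONNX export records constants instead of a Shape node.
    ret = self.shape
    return torch.Size([int(s) for s in ret])

x = torch.rand(1, 3, 16, 16)
print(fold_shape(x))  # torch.Size([1, 3, 16, 16]), each entry a constant int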
mmdeploy/pytorch/functions/group_norm.py:

@@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER

 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.nn.functional.group_norm', backend='ncnn')
-def group_norm_ncnn(
+def group_norm__ncnn(
     ctx,
     input: torch.Tensor,
     num_groups: int,
@@ -15,7 +15,11 @@ def group_norm_ncnn(
     bias: Union[torch.Tensor, torch.NoneType] = None,
     eps: float = 1e-05,
 ) -> torch.Tensor:
-    """Rewrite `group_norm` for NCNN backend."""
+    """Rewrite `group_norm` for NCNN backend.
+
+    InstanceNorm in ncnn requires an input with shape [C, H, W], so we have
+    to reshape the input tensor first.
+    """
     input_shape = input.shape
     batch_size = input_shape[0]
     # We cannot use input.reshape(batch_size, num_groups, -1, 1)
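The numerics behind this rewrite: GroupNorm can be computed as InstanceNorm over a grouped reshape. A standalone sketch in plain PyTorch (the patch itself avoids this exact reshape inside ncnn, but the equivalence below is why the approach works):

import torch
import torch.nn.functional as F

x = torch.rand(2, 8, 4, 4)   # [N, C, H, W]
num_groups = 4
weight, bias = torch.rand(8), torch.rand(8)

# Reshape to [N, num_groups, -1, 1] so each group becomes one "channel";
# instance_norm then normalizes each group over its trailing axes.
shape = x.shape
norm = F.instance_norm(x.reshape(shape[0], num_groups, -1, 1),
                       eps=1e-5).reshape(shape)
out = norm * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

assert torch.allclose(out, F.group_norm(x, num_groups, weight, bias),
                      atol=1e-5)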
mmdeploy/pytorch/functions/interpolate.py:

@@ -1,16 +1,26 @@
+from typing import Optional, Tuple, Union
+
+import torch
+
 from mmdeploy.core import FUNCTION_REWRITER


 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.nn.functional.interpolate', backend='ncnn')
-def interpolate_static(ctx,
-                       input,
-                       size=None,
-                       scale_factor=None,
-                       mode='nearest',
-                       align_corners=None,
-                       recompute_scale_factor=None):
-    """Rewrite `interpolate` for NCNN backend."""
+def interpolate__ncnn(ctx,
+                      input: torch.Tensor,
+                      size: Optional[Union[int, Tuple[int], Tuple[int, int],
+                                           Tuple[int, int, int]]] = None,
+                      scale_factor: Optional[Union[float,
+                                                   Tuple[float]]] = None,
+                      mode: str = 'nearest',
+                      align_corners: Optional[bool] = None,
+                      recompute_scale_factor: Optional[bool] = None):
+    """Rewrite `interpolate` for NCNN backend.
+
+    NCNN requires `size` to be constant in the ONNX node. We use
+    `scale_factor` instead of `size` to avoid a dynamic size.
+    """
     input_size = input.shape
     if scale_factor is None:
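A sketch of the `size`-to-`scale_factor` conversion the docstring describes, assuming the target size divides the input size evenly; plain PyTorch, no mmdeploy:

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 32, 32)

# A fixed `size` becomes a Resize fed by Shape-derived tensors in ONNX;
# an equivalent `scale_factor` keeps the node parameters constant.
target = (64, 64)
scales = [target[i] / x.shape[2 + i] for i in range(2)]
a = F.interpolate(x, size=target, mode='nearest')
b = F.interpolate(x, scale_factor=scales, mode='nearest')
assert a.shape == b.shape == torch.Size([1, 3, 64, 64])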
mmdeploy/pytorch/functions/linear.py:

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Optional, Union

 import torch

@@ -7,13 +7,18 @@ from mmdeploy.core import FUNCTION_REWRITER

 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.nn.functional.linear', backend='ncnn')
-def linear_ncnn(
+def linear__ncnn(
     ctx,
     input: torch.Tensor,
     weight: torch.Tensor,
-    bias: Union[torch.Tensor, torch.NoneType] = None,
+    bias: Optional[Union[torch.Tensor, torch.NoneType]] = None,
 ):
-    """Rewrite `linear` for NCNN backend."""
+    """Rewrite `linear` for NCNN backend.
+
+    The broadcast rules differ between ncnn and PyTorch. This function adds
+    extra reshape and transpose operations to support `linear` on inputs of
+    different shapes.
+    """
     origin_func = ctx.origin_func

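What the extra reshape amounts to, sketched in plain PyTorch: flatten the leading dims to get a strictly 2-D `linear`, then restore them. Names here are illustrative only.

import torch
import torch.nn.functional as F

x = torch.rand(2, 5, 8)  # [batch, tokens, in_features]
w = torch.rand(4, 8)     # [out_features, in_features]
b = torch.rand(4)

# Flatten all leading dims to 2-D, run the plain 2-D linear, then restore
# the leading dims -- the kind of reshape dance a backend with stricter
# broadcast rules needs.
x2d = x.reshape(-1, x.shape[-1])
out = F.linear(x2d, w, b).reshape(*x.shape[:-1], w.shape[0])

assert torch.allclose(out, F.linear(x, w, b), atol=1e-6)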
mmdeploy/pytorch/functions/repeat.py:

@@ -1,10 +1,19 @@
+from typing import Sequence, Union
+
+import torch
+
 from mmdeploy.core import FUNCTION_REWRITER


 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.Tensor.repeat', backend='tensorrt')
-def repeat_static(ctx, input, *size):
-    """Rewrite `repeat` for NCNN backend."""
+def tensor__repeat__tensorrt(ctx, input: torch.Tensor,
+                             *size: Union[torch.Size, Sequence[int]]):
+    """Rewrite `repeat` for TensorRT backend.
+
+    Some layers in TensorRT cannot be applied along the batch axis. Add an
+    extra axis before the operation and remove it afterward.
+    """
     origin_func = ctx.origin_func
     if input.dim() == 1 and len(size) == 1:
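The extra-axis trick from the docstring, shown in isolation on a 1-D input (plain PyTorch; the results are identical):

import torch

x = torch.arange(4)

# Unsqueeze to 2-D, repeat with a leading 1, then squeeze the helper axis,
# sidestepping ops that a backend cannot apply along the batch axis.
out = x.unsqueeze(0).repeat(1, 3).squeeze(0)

assert torch.equal(out, x.repeat(3))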
mmdeploy/pytorch/functions/size.py:

@@ -5,8 +5,12 @@ from mmdeploy.core import FUNCTION_REWRITER

 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.Tensor.size', backend='ncnn')
-def size_of_tensor_static(ctx, self, *args):
-    """Rewrite `size` for NCNN backend."""
+def tensor__size__ncnn(ctx, self, *args):
+    """Rewrite `size` for NCNN backend.
+
+    The ONNX Shape node is not supported in ncnn. This function returns
+    integers instead of torch.Size to avoid the ONNX Shape node.
+    """
     ret = ctx.origin_func(self, *args)
     if isinstance(ret, torch.Tensor):
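A standalone sketch of the constant-size idea, assuming a plain eager tensor; `static_size` is a hypothetical helper, not the registered rewrite:

import torch

def static_size(t: torch.Tensor, *args):
    # Call the original `size`, then force every element to a plain int so
    # tracing records constants rather than emitting an ONNX Shape node.
    ret = t.size(*args)
    if isinstance(ret, torch.Size):
        return torch.Size([int(s) for s in ret])
    return int(ret)

x = torch.rand(2, 3)
print(static_size(x))     # torch.Size([2, 3]) with constant ints
print(static_size(x, 1))  # 3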
mmdeploy/pytorch/functions/topk.py:

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch

 from mmdeploy.core import FUNCTION_REWRITER
@@ -6,8 +8,16 @@ from mmdeploy.core import FUNCTION_REWRITER
 @FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='default')
 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.Tensor.topk', backend='default')
-def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
-    """Rewrite `topk` for default backend."""
+def topk__dynamic(ctx,
+                  input: torch.Tensor,
+                  k: int,
+                  dim: Optional[int] = None,
+                  largest: bool = True,
+                  sorted: bool = True):
+    """Rewrite `topk` for default backend.
+
+    Cast k to a tensor and make sure k is smaller than input.shape[dim].
+    """
     if dim is None:
         dim = int(input.ndim - 1)
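The guard described in this docstring, sketched without the rewriter machinery:

import torch

x = torch.rand(8)
k = 12  # larger than the axis length

# Clamp k against the axis size before calling topk; a k above
# input.shape[dim] would otherwise raise at runtime.
k = min(k, x.shape[-1])
values, indices = torch.topk(x, k)
assert values.shape[0] == 8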
@@ -25,8 +35,17 @@ def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
     func_name='torch.topk', backend='tensorrt')
 @FUNCTION_REWRITER.register_rewriter(
     func_name='torch.Tensor.topk', backend='tensorrt')
-def topk_static(ctx, input, k, dim=None, largest=True, sorted=True):
-    """Rewrite `topk` for TensorRT backend."""
+def topk__tensorrt(ctx,
+                   input: torch.Tensor,
+                   k: int,
+                   dim: Optional[int] = None,
+                   largest: bool = True,
+                   sorted: bool = True):
+    """Rewrite `topk` for TensorRT backend.
+
+    TensorRT does not support topk with a dynamic k. This function casts k
+    to a constant integer.
+    """
     if dim is None:
         dim = int(input.ndim - 1)
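A sketch of the cast, assuming k arrives as a 0-d tensor during tracing:

import torch

x = torch.rand(2, 16)
k = torch.tensor(5)  # k arriving as a 0-d tensor

# TensorRT wants a compile-time constant: pull the Python int out of the
# tensor so the exported TopK node carries a fixed k.
values, indices = torch.topk(x, int(k), dim=1)
assert values.shape == (2, 5)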
mmdeploy/pytorch/ops/__init__.py:

@@ -1,11 +1,12 @@
-from .adaptive_avg_pool import (adaptive_avg_pool1d_op, adaptive_avg_pool2d_op,
-                                adaptive_avg_pool3d_op)
-from .grid_sampler import grid_sampler_default
-from .instance_norm import instance_norm_trt
-from .squeeze import squeeze_default
+from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
+                                adaptive_avg_pool2d__default,
+                                adaptive_avg_pool3d__default)
+from .grid_sampler import grid_sampler__default
+from .instance_norm import instance_norm__tensorrt
+from .squeeze import squeeze__default

 __all__ = [
-    'adaptive_avg_pool1d_op', 'adaptive_avg_pool2d_op',
-    'adaptive_avg_pool3d_op', 'grid_sampler_default', 'instance_norm_trt',
-    'squeeze_default'
+    'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
+    'adaptive_avg_pool3d__default', 'grid_sampler__default',
+    'instance_norm__tensorrt', 'squeeze__default'
 ]
mmdeploy/pytorch/ops/adaptive_avg_pool.py:

@@ -1,3 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 # Modified from:
 # https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py

@@ -53,18 +54,27 @@ adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',


 @SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
-def adaptive_avg_pool1d_op(ctx, *args):
-    """Register default symbolic function for `adaptive_avg_pool1d`."""
+def adaptive_avg_pool1d__default(ctx, *args):
+    """Register default symbolic function for `adaptive_avg_pool1d`.
+
+    Align the symbolic of adaptive_pool across different torch versions.
+    """
     return adaptive_avg_pool1d(*args)


 @SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
-def adaptive_avg_pool2d_op(ctx, *args):
-    """Register default symbolic function for `adaptive_avg_pool2d`."""
+def adaptive_avg_pool2d__default(ctx, *args):
+    """Register default symbolic function for `adaptive_avg_pool2d`.
+
+    Align the symbolic of adaptive_pool across different torch versions.
+    """
     return adaptive_avg_pool2d(*args)


 @SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
-def adaptive_avg_pool3d_op(ctx, *args):
-    """Register default symbolic function for `adaptive_avg_pool3d`."""
+def adaptive_avg_pool3d__default(ctx, *args):
+    """Register default symbolic function for `adaptive_avg_pool3d`.
+
+    Align the symbolic of adaptive_pool across different torch versions.
+    """
     return adaptive_avg_pool3d(*args)
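Why the symbolic can lower adaptive pooling to a plain AveragePool, sketched for the evenly divisible case (plain PyTorch):

import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 32, 32)

# When the input size is divisible by the output size, adaptive pooling is
# an ordinary average pool with a derived kernel -- which is what an
# AveragePool-based ONNX symbolic lowers it to.
out_size = (8, 8)
kernel = (x.shape[2] // out_size[0], x.shape[3] // out_size[1])
a = F.adaptive_avg_pool2d(x, out_size)
b = F.avg_pool2d(x, kernel_size=kernel, stride=kernel)
assert torch.allclose(a, b, atol=1e-6)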
mmdeploy/pytorch/ops/grid_sampler.py:

@@ -10,7 +10,12 @@ def grid_sampler(g,
                  interpolation_mode,
                  padding_mode,
                  align_corners=False):
-    """Symbolic function for `grid_sampler`."""
+    """Symbolic function for `grid_sampler`.
+
+    PyTorch does not support exporting grid_sampler to ONNX by default. We
+    add the support here. `grid_sampler` will be exported as the ONNX node
+    'mmcv::grid_sampler'.
+    """
     return g.op(
         'mmcv::grid_sampler',
         input,
@@ -21,6 +26,9 @@ def grid_sampler(g,


 @SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
-def grid_sampler_default(ctx, *args):
-    """Register default symbolic function for `grid_sampler`."""
+def grid_sampler__default(ctx, *args):
+    """Register default symbolic function for `grid_sampler`.
+
+    Add support for exporting grid_sampler to ONNX.
+    """
     return grid_sampler(*args)
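For context, a hedged sketch of registering such a symbolic with stock `torch.onnx` APIs; mmdeploy's SYMBOLIC_REWRITER presumably wraps a similar mechanism, and the opset version here is an assumption:

import torch
from torch.onnx import register_custom_op_symbolic

def grid_sampler_symbolic(g, input, grid, interpolation_mode, padding_mode,
                          align_corners):
    # Emit the custom node the docstring names; the `_i` suffix marks each
    # keyword as an ONNX integer attribute.
    return g.op(
        'mmcv::grid_sampler',
        input,
        grid,
        interpolation_mode_i=interpolation_mode,
        padding_mode_i=padding_mode,
        align_corners_i=align_corners)

# Opset 11 is an assumed choice for this sketch.
register_custom_op_symbolic('aten::grid_sampler', grid_sampler_symbolic, 11)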
mmdeploy/pytorch/ops/instance_norm.py:

@@ -1,3 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 # Modified from:
 # https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py

@@ -63,7 +64,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):

 @SYMBOLIC_REWRITER.register_symbolic(
     'group_norm', backend='tensorrt', is_pytorch=True)
-def instance_norm_trt(ctx, *args):
+def instance_norm__tensorrt(ctx, *args):
     """Register symbolic function for TensorRT backend.

     Notes:
mmdeploy/pytorch/ops/squeeze.py:

@@ -4,8 +4,12 @@ from mmdeploy.core import SYMBOLIC_REWRITER


 @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
-def squeeze_default(ctx, g, self, dim=None):
-    """Register default symbolic function for `squeeze`."""
+def squeeze__default(ctx, g, self, dim=None):
+    """Register default symbolic function for `squeeze`.
+
+    `squeeze` might be exported with an If node in ONNX, which is not
+    supported by many backends.
+    """
     if dim is None:
         dims = []
         for i, size in enumerate(self.type().sizes()):
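The If-node avoidance in isolation: resolve the squeezable axes from the traced (constant) shape up front, so the export emits Squeeze with explicit axes instead of data-dependent branches. A plain-PyTorch sketch:

import torch

x = torch.rand(1, 3, 1, 5)

# Collect axes of size 1, then squeeze them in reverse order so earlier
# indices stay valid as dimensions disappear.
dims = [i for i, s in enumerate(x.shape) if s == 1]
out = x
for d in sorted(dims, reverse=True):
    out = out.squeeze(d)

assert out.shape == torch.Size([3, 5])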