[Enhancement] Update function name and docstring in mmdeploy.pytorch (#191)

* Update function name and docstring in mmdeploy.python

* remove  in docstring
This commit is contained in:
q.yao 2021-11-17 14:20:29 +08:00 committed by GitHub
parent a4dceb4bb4
commit acf1dc5d88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 133 additions and 54 deletions

View File

@ -1,13 +1,13 @@
from .getattribute import getattribute_static
from .group_norm import group_norm_ncnn
from .interpolate import interpolate_static
from .linear import linear_ncnn
from .repeat import repeat_static
from .size import size_of_tensor_static
from .topk import topk_dynamic, topk_static
from .getattribute import tensor__getattribute__ncnn
from .group_norm import group_norm__ncnn
from .interpolate import interpolate__ncnn
from .linear import linear__ncnn
from .repeat import tensor__repeat__tensorrt
from .size import tensor__size__ncnn
from .topk import topk__dynamic, topk__tensorrt
__all__ = [
'getattribute_static', 'group_norm_ncnn', 'interpolate_static',
'linear_ncnn', 'repeat_static', 'size_of_tensor_static', 'topk_static',
'topk_dynamic'
'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn',
'linear__ncnn', 'tensor__repeat__tensorrt', 'tensor__size__ncnn',
'topk__dynamic', 'topk__tensorrt'
]

View File

@ -5,8 +5,12 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.__getattribute__', backend='ncnn')
def getattribute_static(ctx, self, name):
"""Rewrite `__getattribute__` for NCNN backend."""
def tensor__getattribute__ncnn(ctx, self: torch.Tensor, name: str):
"""Rewrite `__getattribute__` of `torch.Tensor` for NCNN backend.
Shape node is not supported by ncnn. This function transform dynamic shape
to constant shape.
"""
ret = ctx.origin_func(self, name)
if name == 'shape':

View File

@ -7,7 +7,7 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.group_norm', backend='ncnn')
def group_norm_ncnn(
def group_norm__ncnn(
ctx,
input: torch.Tensor,
num_groups: int,
@ -15,7 +15,11 @@ def group_norm_ncnn(
bias: Union[torch.Tensor, torch.NoneType] = None,
eps: float = 1e-05,
) -> torch.Tensor:
"""Rewrite `group_norm` for NCNN backend."""
"""Rewrite `group_norm` for NCNN backend.
InstanceNorm in ncnn require input with shape [C, H, W]. So we have to
reshape the input tensor before it.
"""
input_shape = input.shape
batch_size = input_shape[0]
# We cannot use input.reshape(batch_size, num_groups, -1, 1)

View File

@ -1,16 +1,26 @@
from typing import Optional, Tuple, Union
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.interpolate', backend='ncnn')
def interpolate_static(ctx,
input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
recompute_scale_factor=None):
"""Rewrite `interpolate` for NCNN backend."""
def interpolate__ncnn(ctx,
input: torch.Tensor,
size: Optional[Union[int, Tuple[int], Tuple[int, int],
Tuple[int, int, int]]] = None,
scale_factor: Optional[Union[float,
Tuple[float]]] = None,
mode: str = 'nearest',
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None):
"""Rewrite `interpolate` for NCNN backend.
NCNN require `size` should be constant in ONNX Node. We use `scale_factor`
instead of `size` to avoid dynamic size.
"""
input_size = input.shape
if scale_factor is None:

View File

@ -1,4 +1,4 @@
from typing import Union
from typing import Optional, Union
import torch
@ -7,13 +7,18 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.nn.functional.linear', backend='ncnn')
def linear_ncnn(
def linear__ncnn(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Union[torch.Tensor, torch.NoneType] = None,
bias: Optional[Union[torch.Tensor, torch.NoneType]] = None,
):
"""Rewrite `linear` for NCNN backend."""
"""Rewrite `linear` for NCNN backend.
The broadcast rules are different between ncnn and PyTorch. This function
add extra reshape and transpose to support linear operation of different
input shape.
"""
origin_func = ctx.origin_func

View File

@ -1,10 +1,19 @@
from typing import Sequence, Union
import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
"""Rewrite `repeat` for NCNN backend."""
def tensor__repeat__tensorrt(ctx, input: torch.Tensor,
*size: Union[torch.Size, Sequence[int]]):
"""Rewrite `repeat` for TensorRT backend.
Some layers in TensorRT can not be applied on batch axis. add extra axis
before operation and remove it afterward.
"""
origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:

View File

@ -5,8 +5,12 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.size', backend='ncnn')
def size_of_tensor_static(ctx, self, *args):
"""Rewrite `size` for NCNN backend."""
def tensor__size__ncnn(ctx, self, *args):
"""Rewrite `size` for NCNN backend.
ONNX Shape node is not supported in ncnn. This function return integal
instead of Torch.Size to avoid ONNX Shape node.
"""
ret = ctx.origin_func(self, *args)
if isinstance(ret, torch.Tensor):

View File

@ -1,3 +1,5 @@
from typing import Optional
import torch
from mmdeploy.core import FUNCTION_REWRITER
@ -6,8 +8,16 @@ from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(func_name='torch.topk', backend='default')
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='default')
def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
"""Rewrite `topk` for default backend."""
def topk__dynamic(ctx,
input: torch.Tensor,
k: int,
dim: Optional[int] = None,
largest: bool = True,
sorted: bool = True):
"""Rewrite `topk` for default backend.
Cast k to tensor and makesure k is smaller than input.shape[dim].
"""
if dim is None:
dim = int(input.ndim - 1)
@ -25,8 +35,17 @@ def topk_dynamic(ctx, input, k, dim=None, largest=True, sorted=True):
func_name='torch.topk', backend='tensorrt')
@FUNCTION_REWRITER.register_rewriter(
func_name='torch.Tensor.topk', backend='tensorrt')
def topk_static(ctx, input, k, dim=None, largest=True, sorted=True):
"""Rewrite `topk` for TensorRT backend."""
def topk__tensorrt(ctx,
input: torch.Tensor,
k: int,
dim: Optional[int] = None,
largest: bool = True,
sorted: bool = True):
"""Rewrite `topk` for TensorRT backend.
TensorRT does not support topk with dynamic k. This function cast k to
constant integer.
"""
if dim is None:
dim = int(input.ndim - 1)

View File

@ -1,11 +1,12 @@
from .adaptive_avg_pool import (adaptive_avg_pool1d_op, adaptive_avg_pool2d_op,
adaptive_avg_pool3d_op)
from .grid_sampler import grid_sampler_default
from .instance_norm import instance_norm_trt
from .squeeze import squeeze_default
from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
adaptive_avg_pool2d__default,
adaptive_avg_pool3d__default)
from .grid_sampler import grid_sampler__default
from .instance_norm import instance_norm__tensorrt
from .squeeze import squeeze__default
__all__ = [
'adaptive_avg_pool1d_op', 'adaptive_avg_pool2d_op',
'adaptive_avg_pool3d_op', 'grid_sampler_default', 'instance_norm_trt',
'squeeze_default'
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'instance_norm__tensorrt', 'squeeze__default'
]

View File

@ -1,3 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
@ -53,18 +54,27 @@ adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool1d', is_pytorch=True)
def adaptive_avg_pool1d_op(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool1d`."""
def adaptive_avg_pool1d__default(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool1d`.
Align symbolic of adaptive_pool between different torch version.
"""
return adaptive_avg_pool1d(*args)
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool2d', is_pytorch=True)
def adaptive_avg_pool2d_op(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool2d`."""
def adaptive_avg_pool2d__default(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool2d`.
Align symbolic of adaptive_pool between different torch version.
"""
return adaptive_avg_pool2d(*args)
@SYMBOLIC_REWRITER.register_symbolic('adaptive_avg_pool3d', is_pytorch=True)
def adaptive_avg_pool3d_op(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool3d`."""
def adaptive_avg_pool3d__default(ctx, *args):
"""Register default symbolic function for `adaptive_avg_pool3d`.
Align symbolic of adaptive_pool between different torch version.
"""
return adaptive_avg_pool3d(*args)

View File

@ -10,7 +10,12 @@ def grid_sampler(g,
interpolation_mode,
padding_mode,
align_corners=False):
"""Symbolic function for `grid_sampler`."""
"""Symbolic function for `grid_sampler`.
PyTorch does not support export grid_sampler to ONNX by default. We add the
support here. `grid_sampler` will be exported as ONNX node
'mmcv::grid_sampler'
"""
return g.op(
'mmcv::grid_sampler',
input,
@ -21,6 +26,9 @@ def grid_sampler(g,
@SYMBOLIC_REWRITER.register_symbolic('grid_sampler', is_pytorch=True)
def grid_sampler_default(ctx, *args):
"""Register default symbolic function for `grid_sampler`."""
def grid_sampler__default(ctx, *args):
"""Register default symbolic function for `grid_sampler`.
Add support to grid_sample to ONNX.
"""
return grid_sampler(*args)

View File

@ -1,3 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
@ -63,7 +64,7 @@ def instance_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
@SYMBOLIC_REWRITER.register_symbolic(
'group_norm', backend='tensorrt', is_pytorch=True)
def instance_norm_trt(ctx, *args):
def instance_norm__tensorrt(ctx, *args):
"""Register symbolic function for TensorRT backend.
Notes:

View File

@ -4,8 +4,12 @@ from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
"""Register default symbolic function for `squeeze`."""
def squeeze__default(ctx, g, self, dim=None):
"""Register default symbolic function for `squeeze`.
squeeze might be exported with IF node in ONNX, which is not supported in
lots of backend.
"""
if dim is None:
dims = []
for i, size in enumerate(self.type().sizes()):