[Fix] Fix for torch113 for master ()

pull/1508/head
q.yao 2022-12-08 17:17:27 +08:00 committed by GitHub
parent 4046e13146
commit 8ea3dc943f
9 changed files with 51 additions and 178 deletions
tests/test_mmcv


@@ -30,3 +30,10 @@ def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
disable for TensorRT.
"""
return param_dict
@FUNCTION_REWRITER.register_rewriter(
'torch._C._jit_pass_onnx_autograd_function_process')
def jit_pass_onnx_autograd_function_process__disable(ctx, graph):
"""Disable autograd function processing."""
return
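
The new rewriter above disables torch._C._jit_pass_onnx_autograd_function_process during export, mirroring the existing workaround for _jit_pass_onnx_deduplicate_initializers. Outside mmdeploy's rewrite machinery, the same effect can be sketched as a plain monkey-patch (illustration only; assumes torch exposes this private pass):

# Illustrative sketch, not part of this commit: replace the private JIT pass
# with a no-op so autograd Functions are left untouched during ONNX export.
import torch

_original_pass = getattr(torch._C,
                         '_jit_pass_onnx_autograd_function_process', None)


def _noop_autograd_function_process(graph):
    """Leave the traced graph as-is."""
    return


if _original_pass is not None:
    torch._C._jit_pass_onnx_autograd_function_process = \
        _noop_autograd_function_process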


@@ -12,6 +12,14 @@ class GridPriorsTRTOp(torch.autograd.Function):
@staticmethod
def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int,
stride_w: int):
"""Generate grid priors by base anchors."""
# torch>=1.13 raises a runtime error when torch.arange is used
# inside an autograd function
output = getattr(GridPriorsTRTOp, 'output', None)
if output is not None:
return output
device = base_anchors.device
dtype = base_anchors.dtype
shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
@@ -39,8 +47,8 @@ class GridPriorsTRTOp(torch.autograd.Function):
def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
stride_w: int):
# zero_h and zero_w are used to provide shape to GridPriorsTRT
feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
feat_h = g.op('Unsqueeze', feat_h, axes_i=[0])
feat_w = g.op('Unsqueeze', feat_w, axes_i=[0])
zero_h = g.op(
'ConstantOfShape',
feat_h,
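
The hunk above drops symbolic_helper._unsqueeze_helper, presumably because the private helper changed in torch 1.13, and emits the Unsqueeze node directly. A hedged sketch of the two equivalent spellings (illustration only; which one applies depends on the target opset):

# Illustrative sketch, not part of this commit.
import torch


def unsqueeze0_attr(g, value):
    # opset <= 12: axes is passed as an attribute
    return g.op('Unsqueeze', value, axes_i=[0])


def unsqueeze0_input(g, value):
    # opset >= 13: axes is passed as a Constant input
    axes = g.op('Constant', value_t=torch.tensor([0], dtype=torch.int64))
    return g.op('Unsqueeze', value, axes)
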
@@ -90,9 +98,11 @@ def anchorgenerator__single_level_grid_priors__trt(
torch.Tensor: Anchors in the overall feature maps.
"""
feat_h, feat_w = featmap_size
output = ctx.origin_func(self, featmap_size, level_idx, dtype, device).data
if isinstance(feat_h, int) and isinstance(feat_w, int):
return ctx.origin_func(self, featmap_size, level_idx, dtype,
device).data
return output
base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
stride_w, stride_h = self.strides[level_idx]
GridPriorsTRTOp.output = output
return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w)
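
In the hunk above, the rewritten single_level_grid_priors now computes the anchors eagerly with ctx.origin_func and parks the result on GridPriorsTRTOp.output, because torch.arange inside an autograd.Function raises a runtime error on torch >= 1.13; forward then simply returns the cached tensor. A minimal sketch of this caching pattern with illustrative names (not mmdeploy API):

# Illustrative sketch, not part of this commit: precompute the value outside
# the autograd.Function and let forward() hand back the cached tensor, so no
# tensor ops run inside the Function when exporting on torch >= 1.13.
import torch


class CachedOp(torch.autograd.Function):
    output = None  # filled in by the caller right before apply()

    @staticmethod
    def forward(ctx, x):
        cached = getattr(CachedOp, 'output', None)
        if cached is not None:
            return cached
        return x.clone()  # eager fallback when nothing was cached


def call_with_cache(x, eager_fn):
    CachedOp.output = eager_fn(x)  # e.g. the original arange-based routine
    return CachedOp.apply(x)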


@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from packaging import version
from torch import Tensor
from mmdeploy.core import FUNCTION_REWRITER, mark
@@ -88,7 +89,9 @@ def _multiclass_nms(boxes: Tensor,
shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
4).
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
if version.parse(torch.__version__) < version.parse('1.13.0'):
max_output_boxes_per_class = torch.LongTensor(
[max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
batch_size = scores.shape[0]
@@ -122,7 +125,9 @@ def _multiclass_nms_single(boxes: Tensor,
Single batch nms could be optimized.
"""
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
if version.parse(torch.__version__) < version.parse('1.13.0'):
max_output_boxes_per_class = torch.LongTensor(
[max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
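
Both _multiclass_nms variants above now wrap max_output_boxes_per_class in a LongTensor only on torch < 1.13; on newer torch the plain Python scalar is passed through unchanged. A hedged sketch of the gate, with an illustrative helper name:

# Illustrative sketch, not part of this commit.
import torch
from packaging import version


def as_max_boxes_arg(max_output_boxes_per_class: int):
    """Keep the pre-1.13 behaviour of wrapping the scalar in a LongTensor."""
    if version.parse(torch.__version__) < version.parse('1.13.0'):
        return torch.LongTensor([max_output_boxes_per_class])
    return max_output_boxes_per_class


print(type(as_max_boxes_arg(200)))  # Tensor on torch < 1.13, int otherwise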


@@ -269,8 +269,8 @@ class SingleRoIExtractorOpenVINO(Function):
@staticmethod
def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats):
from torch.onnx.symbolic_helper import _slice_helper
rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
from torch.onnx.symbolic_opset10 import _slice
rois = _slice(g, rois, axes=[1], starts=[1], ends=[5])
domain = 'org.openvinotoolkit'
op_name = 'ExperimentalDetectronROIFeatureExtractor'
roi_feats = g.op(
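
The OpenVINO ROI-extractor symbolic above switches from symbolic_helper._slice_helper to _slice from torch.onnx.symbolic_opset10, presumably because the private helper changed in torch 1.13. An equivalent, helper-free way to take rois[:, 1:5] is to build the ONNX Slice node directly (illustrative sketch; the function name is not from this commit):

# Illustrative sketch, not part of this commit: slice away the batch-index
# column of Nx5 rois with a plain onnx::Slice (opset >= 10 form, where
# starts/ends/axes are inputs rather than attributes).
import torch


def slice_roi_coords(g, rois):
    starts = g.op('Constant', value_t=torch.tensor([1], dtype=torch.int64))
    ends = g.op('Constant', value_t=torch.tensor([5], dtype=torch.int64))
    axes = g.op('Constant', value_t=torch.tensor([1], dtype=torch.int64))
    return g.op('Slice', rois, starts, ends, axes)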


@@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import context_block # noqa: F401,F403
from . import conv2d_adaptive_padding # noqa: F401,F403
from . import hsigmoid # noqa: F401,F403
from . import hswish # noqa: F401,F403
from .transformer import MultiHeadAttentionop
__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop']
__all__ = ['MultiHeadAttentionop']


@@ -1,86 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
def compute_padding(input_size, kernel_size, stride, dilation):
"""Compute padding."""
input_h, input_w = input_size
kernel_h, kernel_w = kernel_size
stride_h, stride_w = stride
dilation_h, dilation_w = dilation
output_h = math.ceil(input_h / stride_h)
output_w = math.ceil(input_w / stride_w)
pad_h = max(
(output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h,
0)
pad_w = max(
(output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w,
0)
if pad_w > 0 or pad_h > 0:
padded = [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
]
else:
padded = None
return padded
class AdaptivePadOp(torch.autograd.Function):
"""Dummy adaptive pad op."""
@staticmethod
def forward(ctx, x, padded):
if padded is not None:
x = F.pad(x, padded)
return x
@staticmethod
def symbolic(g, x, padded):
if padded is None:
return g.op('Identity', x)
padded = g.op(
'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
constant_value = g.op(
'Constant', value_t=torch.tensor(0, dtype=torch.int64))
return g.op(
'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. \
Conv2dAdaptivePadding.forward',
backend=Backend.TENSORRT.value)
def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
"""Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for the
TensorRT backend. The main change of this rewritten function is to separate
the computation of padding and encapsulate it into another
`torch.autograd.Function` so that the adaptive padding can be parsed as a
`Pad` op in ONNX, with the padding information computed in advance (only
for static shape configurations).
Args:
x (Tensor): Input tensor of Conv2dAdaptivePadding ops
Returns:
Tensor: forward result of 2D convolution after padding
"""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
padded = compute_padding(x.shape[2:], self.weight.shape[2:],
self.stride, self.dilation)
if padded is not None:
padded = [int(_) for _ in padded]
x = AdaptivePadOp.apply(x, padded)
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)
else:
x = ctx.origin_func(x)
return x
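
The file above, the TensorRT rewriter for Conv2dAdaptivePadding that precomputed TF-style 'same' padding so it could be exported as a static ONNX Pad, is deleted by this commit along with its import, __all__ entry and test. For reference, a quick numeric check of the formula the removed compute_padding helper implemented (illustrative numbers):

# Illustrative check, not part of this commit: 3x3 kernel, stride 2,
# dilation 1 on a 224-pixel edge.
import math

input_h, kernel_h, stride_h, dilation_h = 224, 3, 2, 1
output_h = math.ceil(input_h / stride_h)
pad_h = max((output_h - 1) * stride_h
            + (kernel_h - 1) * dilation_h + 1 - input_h, 0)
print(output_h, pad_h)  # 112 1 -> pad one row to keep the 'same' output size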


@@ -3,8 +3,6 @@ import torch
from torch import Tensor
from torch.onnx import symbolic_helper as sym_help
from mmdeploy.core import SYMBOLIC_REWRITER
class ONNXNMSop(torch.autograd.Function):
"""Create onnx::NonMaxSuppression op.
@@ -78,57 +76,24 @@ class ONNXNMSop(torch.autograd.Function):
Returns:
NonMaxSuppression op for onnx.
"""
return g.op(
'NonMaxSuppression',
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
outputs=1)
if not sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = g.op(
'Constant',
value_t=torch.tensor(
max_output_boxes_per_class, dtype=torch.long))
@SYMBOLIC_REWRITER.register_symbolic(
'mmdeploy.mmcv.ops.ONNXNMSop', backend='default')
def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
max_output_boxes_per_class: int, iou_threshold: float,
score_threshold: float):
"""Rewrite symbolic function for default backend.
if not sym_help._is_value(iou_threshold):
iou_threshold = g.op(
'Constant',
value_t=torch.tensor([iou_threshold], dtype=torch.float))
Support max_output_boxes_per_class, iou_threshold, score_threshold of
constant Tensor, which is aligned with ONNX's nms op.
Args:
ctx (ContextCaller): The context with additional information.
g (Graph): The traced onnx graph.
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms.
iou_threshold (float): IOU threshold of nms.
score_threshold (float): score threshold of nms.
Returns:
NonMaxSuppression op for onnx.
"""
if not sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = g.op(
'Constant',
value_t=torch.tensor(max_output_boxes_per_class, dtype=torch.long))
if not sym_help._is_value(iou_threshold):
iou_threshold = g.op(
'Constant',
value_t=torch.tensor([iou_threshold], dtype=torch.float))
if not sym_help._is_value(score_threshold):
score_threshold = g.op(
'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float))
return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold)
if not sym_help._is_value(score_threshold):
score_threshold = g.op(
'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float))
return g.op('NonMaxSuppression', boxes, scores,
max_output_boxes_per_class, iou_threshold, score_threshold)
class TRTBatchedNMSop(torch.autograd.Function):
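
The hunk above removes the separate nms_dynamic symbolic rewriter; ONNXNMSop.symbolic itself now wraps each argument in an onnx::Constant when it is not already a graph value, so a single symbolic serves both traced tensors and plain Python scalars. A hedged sketch of that check, with an illustrative helper name:

# Illustrative sketch, not part of this commit.
import torch
from torch.onnx import symbolic_helper as sym_help


def to_graph_value(g, value, dtype):
    """Wrap a Python scalar in an onnx::Constant unless tracing already
    produced a graph value for it."""
    if sym_help._is_value(value):
        return value
    return g.op('Constant', value_t=torch.tensor([value], dtype=dtype))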


@@ -56,17 +56,17 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
aligned_i=aligned)
else:
from torch.onnx.symbolic_opset9 import _cast_Long
from torch.onnx.symbolic_opset11 import add, select, squeeze
from torch.onnx.symbolic_opset11 import add, select
batch_indices = _cast_Long(
g,
squeeze(
g,
g.op(
'Squeeze',
select(
g, rois, 1,
g.op(
'Constant',
value_t=torch.tensor([0], dtype=torch.long))), 1),
False)
value_t=torch.tensor([0], dtype=torch.long))),
axes_i=[1]), False)
rois = select(
g, rois, 1,
g.op(
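
The change above stops importing squeeze from torch.onnx.symbolic_opset11, presumably because the helper is not usable the same way on torch 1.13, and emits the Squeeze node directly with an axes_i attribute. A hedged, helper-free sketch of how the batch-index column is pulled out of Nx5 rois (the function name is illustrative):

# Illustrative sketch, not part of this commit: take column 0 of Nx5 rois,
# squeeze the trailing dim and cast to int64 (Cast to_i=7 is TensorProto.INT64).
import torch


def roi_batch_indices(g, rois):
    col0 = g.op('Constant', value_t=torch.tensor([0], dtype=torch.long))
    batch = g.op('Gather', rois, col0, axis_i=1)   # shape (N, 1)
    batch = g.op('Squeeze', batch, axes_i=[1])     # shape (N,)
    return g.op('Cast', batch, to_i=7)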


@@ -34,33 +34,6 @@ def test_multiheadattention_ncnn():
model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)
def test_conv2d_adaptive_padding_tensorrt():
check_backend(Backend.TENSORRT)
from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding
in_channels, out_channels = 3, 64
kernel_sz = 3
model = Conv2dAdaptivePadding(in_channels, out_channels, kernel_sz)
dummy_input = torch.rand(1, 3, 256, 256)
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(input_shape=None),
backend_config=dict(type=Backend.TENSORRT.value),
))
model_outputs = model(dummy_input)
rewrite_inputs = dict(x=dummy_input)
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg,
run_with_backend=True)
if is_backend_output is None:
assert rewrite_outputs is not None
else:
assert torch.allclose(
model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)
def test_context_block_ncnn():
check_backend(Backend.NCNN)
from mmcv.cnn.bricks.context_block import ContextBlock