[Fix] Fix for torch113 for master (#1488)

parent 4046e13146
commit 8ea3dc943f
@@ -30,3 +30,10 @@ def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
     disable for TensorRT.
     """
     return param_dict
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'torch._C._jit_pass_onnx_autograd_function_process')
+def jit_pass_onnx_autograd_function_process__disable(ctx, graph):
+    """Disable process autograd function."""
+    return
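Note: registering a rewriter that returns immediately is the mechanism this hunk uses to skip a `torch._C` JIT pass during export. A minimal sketch of the same pattern, with a hypothetical pass name:

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    'torch._C._jit_pass_some_onnx_pass')  # hypothetical pass name
def jit_pass_some_onnx_pass__disable(ctx, graph, *args, **kwargs):
    """No-op rewrite: FUNCTION_REWRITER swaps in this function during
    export, so the pass never touches the traced graph."""
    return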
@@ -12,6 +12,14 @@ class GridPriorsTRTOp(torch.autograd.Function):
     @staticmethod
     def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int,
                 stride_w: int):
+        """Generate grid priors by base anchors."""
+
+        # torch>=1.13 has runtime error
+        # when using torch.arange in autograd function
+        output = getattr(GridPriorsTRTOp, 'output', None)
+        if output is not None:
+            return output
+
         device = base_anchors.device
         dtype = base_anchors.dtype
         shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
@@ -39,8 +47,8 @@ class GridPriorsTRTOp(torch.autograd.Function):
     def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
                  stride_w: int):
         # zero_h and zero_w is used to provide shape to GridPriorsTRT
-        feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
-        feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
+        feat_h = g.op('Unsqueeze', feat_h, axes_i=[0])
+        feat_w = g.op('Unsqueeze', feat_w, axes_i=[0])
         zero_h = g.op(
             'ConstantOfShape',
             feat_h,
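Note: `symbolic_helper._unsqueeze_helper` changed its calling convention in torch 1.13, which is why the hunk emits the node directly. The `axes_i` keyword becomes an ONNX attribute, which is only valid for opset < 13 (opset 13 moves `axes` to a second input). A hedged compatibility sketch, assuming these are the only two variants that matter:

import torch
from packaging import version


def unsqueeze_compat(g, value, axes):
    # hypothetical helper: keep the pre-1.13 private helper where it still
    # works, emit the raw op (attribute form, opset < 13 only) elsewhere
    if version.parse(torch.__version__) < version.parse('1.13.0'):
        from torch.onnx import symbolic_helper
        return symbolic_helper._unsqueeze_helper(g, value, axes)
    return g.op('Unsqueeze', value, axes_i=axes)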
@@ -90,9 +98,11 @@ def anchorgenerator__single_level_grid_priors__trt(
         torch.Tensor: Anchors in the overall feature maps.
     """
     feat_h, feat_w = featmap_size
+    output = ctx.origin_func(self, featmap_size, level_idx, dtype, device).data
     if isinstance(feat_h, int) and isinstance(feat_w, int):
-        return ctx.origin_func(self, featmap_size, level_idx, dtype,
-                               device).data
+        return output
     base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
     stride_w, stride_h = self.strides[level_idx]
+
+    GridPriorsTRTOp.output = output
     return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w)
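Note: the two hunks above work around the torch>=1.13 runtime error with `torch.arange` inside `autograd.Function.forward`: the caller precomputes the result with the original implementation and stashes it on the Function class, so `forward` can return the cached tensor during tracing. A self-contained sketch of the same pattern (all names here are illustrative):

import torch


class CachedOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        # return the precomputed tensor instead of calling torch.arange here
        cached = getattr(CachedOp, 'output', None)
        if cached is not None:
            return cached
        return torch.arange(0, x.shape[-1], device=x.device).to(x.dtype)

    @staticmethod
    def symbolic(g, x):
        # the exported graph is unaffected; only eager tracing uses the cache
        return g.op('Identity', x)


x = torch.rand(1, 8)
CachedOp.output = torch.arange(0, x.shape[-1]).to(x.dtype)  # precompute
y = CachedOp.apply(x)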
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from packaging import version
 from torch import Tensor
 
 from mmdeploy.core import FUNCTION_REWRITER, mark
@@ -88,7 +89,9 @@ def _multiclass_nms(boxes: Tensor,
     shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
     4).
     """
-    max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        max_output_boxes_per_class = torch.LongTensor(
+            [max_output_boxes_per_class])
     iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
     score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
     batch_size = scores.shape[0]
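Note: the guard keeps the old tensor-wrapping behavior only for torch < 1.13, where the exporter required `max_output_boxes_per_class` as a tensor. `packaging.version.parse` is the right comparison tool here since it understands PEP 440 local versions such as CUDA-tagged builds:

from packaging import version

assert version.parse('1.12.1') < version.parse('1.13.0')
# local version segments (e.g. '+cu117') compare as >= the base release
assert not version.parse('1.13.0+cu117') < version.parse('1.13.0')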
@@ -122,7 +125,9 @@ def _multiclass_nms_single(boxes: Tensor,
 
     Single batch nms could be optimized.
     """
-    max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        max_output_boxes_per_class = torch.LongTensor(
+            [max_output_boxes_per_class])
     iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
     score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
 
@@ -269,8 +269,8 @@ class SingleRoIExtractorOpenVINO(Function):
 
     @staticmethod
     def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats):
-        from torch.onnx.symbolic_helper import _slice_helper
-        rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+        from torch.onnx.symbolic_opset10 import _slice
+        rois = _slice(g, rois, axes=[1], starts=[1], ends=[5])
         domain = 'org.openvinotoolkit'
         op_name = 'ExperimentalDetectronROIFeatureExtractor'
         roi_feats = g.op(
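Note: `torch.onnx.symbolic_helper._slice_helper` was removed in torch 1.13, while `torch.onnx.symbolic_opset10._slice` takes the same keyword arguments and exists in earlier releases too. If a fallback for older torch were ever needed, a guarded import is the usual shape (a sketch, assuming these are the only two historical locations):

try:
    from torch.onnx.symbolic_opset10 import _slice
except ImportError:
    # pre-1.13 fallback; _slice_helper accepted the same keyword arguments
    from torch.onnx.symbolic_helper import _slice_helper as _slice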
@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import context_block  # noqa: F401,F403
-from . import conv2d_adaptive_padding  # noqa: F401,F403
 from . import hsigmoid  # noqa: F401,F403
 from . import hswish  # noqa: F401,F403
 from .transformer import MultiHeadAttentionop
 
-__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop']
+__all__ = ['MultiHeadAttentionop']
@@ -1,86 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-
-import torch
-import torch.nn.functional as F
-
-from mmdeploy.core import FUNCTION_REWRITER
-from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
-
-
-def compute_padding(input_size, kernel_size, stride, dilation):
-    """Compute padding."""
-
-    input_h, input_w = input_size
-    kernel_h, kernel_w = kernel_size
-    stride_h, stride_w = stride
-    dilation_h, dilation_w = dilation
-    output_h = math.ceil(input_h / stride_h)
-    output_w = math.ceil(input_w / stride_w)
-    pad_h = max(
-        (output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h,
-        0)
-    pad_w = max(
-        (output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w,
-        0)
-    if pad_w > 0 or pad_h > 0:
-        padded = [
-            pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
-        ]
-    else:
-        padded = None
-    return padded
-
-
-class AdaptivePadOp(torch.autograd.Function):
-    """Dummy adaptive pad op."""
-
-    @staticmethod
-    def forward(ctx, x, padded):
-        if padded is not None:
-            x = F.pad(x, padded)
-        return x
-
-    @staticmethod
-    def symbolic(g, x, padded):
-        if padded is None:
-            return g.op('Identity', x)
-        padded = g.op(
-            'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
-        constant_value = g.op(
-            'Constant', value_t=torch.tensor(0, dtype=torch.int64))
-        return g.op(
-            'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. \
-Conv2dAdaptivePadding.forward',
-    backend=Backend.TENSORRT.value)
-def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
-    """Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for
-    TensorRT backend. Main changes of this rewritten function is to separate
-    the computation of padding and encapsulate it into another
-    `torch.autograd.Function` so that the adaptive padding could be parsed as
-    `Pad` ops in ONNX with the padding information computed in advance (Only
-    for static shape configuration).
-
-    Args:
-        x (Tensor): Input tensor of Conv2dAdaptivePadding ops
-    Returns:
-        Tensor: forward result of 2D convolution after padding
-    """
-
-    deploy_cfg = ctx.cfg
-    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
-        padded = compute_padding(x.shape[2:], self.weight.shape[2:],
-                                 self.stride, self.dilation)
-        if padded is not None:
-            padded = [int(_) for _ in padded]
-        x = AdaptivePadOp.apply(x, padded)
-        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
-                        self.dilation, self.groups)
-    else:
-        x = ctx.origin_func(x)
-        return x
@@ -3,8 +3,6 @@ import torch
 from torch import Tensor
 from torch.onnx import symbolic_helper as sym_help
 
-from mmdeploy.core import SYMBOLIC_REWRITER
-
 
 class ONNXNMSop(torch.autograd.Function):
     """Create onnx::NonMaxSuppression op.
@@ -75,40 +73,6 @@ class ONNXNMSop(torch.autograd.Function):
             iou_threshold (float): IOU threshold of nms.
             score_threshold (float): score threshold of nms.
 
-        Returns:
-            NonMaxSuppression op for onnx.
-        """
-        return g.op(
-            'NonMaxSuppression',
-            boxes,
-            scores,
-            max_output_boxes_per_class,
-            iou_threshold,
-            score_threshold,
-            outputs=1)
-
-
-@SYMBOLIC_REWRITER.register_symbolic(
-    'mmdeploy.mmcv.ops.ONNXNMSop', backend='default')
-def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
-                max_output_boxes_per_class: int, iou_threshold: float,
-                score_threshold: float):
-    """Rewrite symbolic function for default backend.
-
-    Support max_output_boxes_per_class, iou_threshold, score_threshold of
-    constant Tensor, which is aligned with ONNX's nms op.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        g (Graph): The traced onnx graph.
-        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
-        scores (Tensor): The detection scores of shape
-            [N, num_boxes, num_classes].
-        max_output_boxes_per_class (int): Maximum number of output
-            boxes per class of nms.
-        iou_threshold (float): IOU threshold of nms.
-        score_threshold (float): score threshold of nms.
-
         Returns:
             NonMaxSuppression op for onnx.
         """
@@ -116,7 +80,8 @@ def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
     if not sym_help._is_value(max_output_boxes_per_class):
         max_output_boxes_per_class = g.op(
             'Constant',
-            value_t=torch.tensor(max_output_boxes_per_class, dtype=torch.long))
+            value_t=torch.tensor(
+                max_output_boxes_per_class, dtype=torch.long))
 
     if not sym_help._is_value(iou_threshold):
         iou_threshold = g.op(
@@ -127,8 +92,8 @@ def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
         score_threshold = g.op(
             'Constant',
             value_t=torch.tensor([score_threshold], dtype=torch.float))
-    return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class,
-                iou_threshold, score_threshold)
+    return g.op('NonMaxSuppression', boxes, scores,
+                max_output_boxes_per_class, iou_threshold, score_threshold)
 
 
 class TRTBatchedNMSop(torch.autograd.Function):
@@ -56,17 +56,17 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
             aligned_i=aligned)
     else:
         from torch.onnx.symbolic_opset9 import _cast_Long
-        from torch.onnx.symbolic_opset11 import add, select, squeeze
+        from torch.onnx.symbolic_opset11 import add, select
         batch_indices = _cast_Long(
             g,
-            squeeze(
-                g,
+            g.op(
+                'Squeeze',
                 select(
                     g, rois, 1,
                     g.op(
                         'Constant',
-                        value_t=torch.tensor([0], dtype=torch.long))), 1),
-            False)
+                        value_t=torch.tensor([0], dtype=torch.long))),
+                axes_i=[1]), False)
         rois = select(
             g, rois, 1,
             g.op(
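Note: same theme as the `Unsqueeze` change above: the opset11 `squeeze` helper's calling convention changed in torch 1.13, so the node is emitted directly with `axes_i=[1]`. That attribute form matches opset 11 semantics; from opset 13 on, `axes` is a tensor input instead, so this symbolic assumes `opset_version < 13`. A runnable illustration of the two encodings (output paths are arbitrary):

import torch


class SqueezeAxis1(torch.nn.Module):

    def forward(self, x):
        return x.squeeze(1)


x = torch.rand(2, 1, 3)
# opset 11: Squeeze carries axes as a node attribute (what axes_i produces)
torch.onnx.export(SqueezeAxis1(), x, 'squeeze_op11.onnx', opset_version=11)
# opset 13: axes arrives as a second input; the attribute form is rejected
torch.onnx.export(SqueezeAxis1(), x, 'squeeze_op13.onnx', opset_version=13)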
@@ -34,33 +34,6 @@ def test_multiheadattention_ncnn():
         model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)
 
 
-def test_conv2d_adaptive_padding_tensorrt():
-    check_backend(Backend.TENSORRT)
-    from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding
-    in_channels, out_channels = 3, 64
-    kernel_sz = 3
-    model = Conv2dAdaptivePadding(in_channels, out_channels, kernel_sz)
-    dummy_input = torch.rand(1, 3, 256, 256)
-
-    deploy_cfg = mmcv.Config(
-        dict(
-            onnx_config=dict(input_shape=None),
-            backend_config=dict(type=Backend.TENSORRT.value),
-        ))
-    model_outputs = model(dummy_input)
-    rewrite_inputs = dict(x=dummy_input)
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
-        wrapped_model=model,
-        model_inputs=rewrite_inputs,
-        deploy_cfg=deploy_cfg,
-        run_with_backend=True)
-    if is_backend_output is None:
-        assert rewrite_outputs is not None
-    else:
-        assert torch.allclose(
-            model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)
-
-
 def test_context_block_ncnn():
     check_backend(Backend.NCNN)
     from mmcv.cnn.bricks.context_block import ContextBlock