[Fix] Fix for torch113 for master (#1488)
parent 4046e13146
commit 8ea3dc943f

Changed paths:
  mmdeploy
    apis/onnx
    codebase/mmdet
      core
        post_processing
      models/roi_heads
  tests/test_mmcv
@@ -30,3 +30,10 @@ def jit_pass_onnx_deduplicate_initializers__disable(ctx, graph, param_dict,
     disable for TensorRT.
     """
     return param_dict
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'torch._C._jit_pass_onnx_autograd_function_process')
+def jit_pass_onnx_autograd_function_process__disable(ctx, graph):
+    """Disable process autograph function."""
+    return
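The added rewriter mirrors the existing `_jit_pass_onnx_deduplicate_initializers` disable above it: while mmdeploy's export context is active, the torch>=1.13 JIT pass `torch._C._jit_pass_onnx_autograd_function_process` is replaced with a function that does nothing. A minimal sketch of the same patch outside the rewriter machinery, assuming a plain context manager (the helper name is illustrative, not mmdeploy API):

```python
import contextlib

import torch


@contextlib.contextmanager
def no_autograd_function_process():
    """Temporarily make the torch>=1.13 JIT pass a no-op during ONNX export."""
    original = torch._C._jit_pass_onnx_autograd_function_process
    torch._C._jit_pass_onnx_autograd_function_process = lambda graph: None
    try:
        yield
    finally:
        torch._C._jit_pass_onnx_autograd_function_process = original
```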
@@ -12,6 +12,14 @@ class GridPriorsTRTOp(torch.autograd.Function):
     @staticmethod
     def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int,
                 stride_w: int):
         """Generate grid priors by base anchors."""
+
+        # torch>=1.13 has runtime error
+        # when using torch.arange in autograd function
+        output = getattr(GridPriorsTRTOp, 'output', None)
+        if output is not None:
+            return output
+
         device = base_anchors.device
         dtype = base_anchors.dtype
         shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
@@ -39,8 +47,8 @@ class GridPriorsTRTOp(torch.autograd.Function):
     def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
                  stride_w: int):
         # zero_h and zero_w is used to provide shape to GridPriorsTRT
-        feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
-        feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
+        feat_h = g.op('Unsqueeze', feat_h, axes_i=[0])
+        feat_w = g.op('Unsqueeze', feat_w, axes_i=[0])
         zero_h = g.op(
             'ConstantOfShape',
             feat_h,
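Replacing `symbolic_helper._unsqueeze_helper` with a direct `g.op('Unsqueeze', ..., axes_i=[0])` call keeps the symbolic function independent of torch's internal helpers, whose signatures shifted in 1.13. The `axes_i` attribute form targets ONNX opset < 13 (opset 13 moved `axes` to an input). A tiny sketch with a made-up helper name:

```python
def unsqueeze_dim0(g, value):
    # Emit an ONNX Unsqueeze node directly instead of relying on
    # torch.onnx.symbolic_helper internals. The `axes` attribute form is
    # valid for opset < 13; opset >= 13 expects axes as a second input.
    return g.op('Unsqueeze', value, axes_i=[0])
```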
@@ -90,9 +98,11 @@ def anchorgenerator__single_level_grid_priors__trt(
         torch.Tensor: Anchors in the overall feature maps.
     """
     feat_h, feat_w = featmap_size
+    output = ctx.origin_func(self, featmap_size, level_idx, dtype, device).data
     if isinstance(feat_h, int) and isinstance(feat_w, int):
-        return ctx.origin_func(self, featmap_size, level_idx, dtype,
-                               device).data
+        return output
     base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
     stride_w, stride_h = self.strides[level_idx]
+
+    GridPriorsTRTOp.output = output
     return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w)
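Together with the `forward` change further up, the rewriter now computes the anchors once in eager mode through `ctx.origin_func`, stashes the result on the op class, and lets `GridPriorsTRTOp.forward` return that cached tensor instead of re-running `torch.arange` inside the autograd Function, which the in-code comment says hits a runtime error under torch>=1.13. A stripped-down sketch of the cache-on-the-class pattern, with hypothetical names:

```python
import torch


class CachedOp(torch.autograd.Function):
    """Toy version of the cache-on-the-class trick used by GridPriorsTRTOp."""

    # Populated by the caller right before the op is traced/exported.
    output = None

    @staticmethod
    def forward(ctx, x):
        cached = getattr(CachedOp, 'output', None)
        if cached is not None:
            # Skip eager ops (e.g. torch.arange) that misbehave inside an
            # autograd.Function under torch>=1.13 during ONNX export.
            return cached
        return x * 2  # fallback eager computation

    @staticmethod
    def symbolic(g, x):
        # The exported graph still records a custom node; the cached value
        # only short-circuits the Python-side forward during tracing.
        return g.op('mmdeploy::DummyCachedOp', x)


# Caller side: compute the real result eagerly, cache it, then run the op.
x = torch.rand(2, 3)
CachedOp.output = x * 2
y = CachedOp.apply(x)
```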
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
+from packaging import version
 from torch import Tensor

 from mmdeploy.core import FUNCTION_REWRITER, mark
@@ -88,7 +89,9 @@ def _multiclass_nms(boxes: Tensor,
     shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
     4).
     """
-    max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        max_output_boxes_per_class = torch.LongTensor(
+            [max_output_boxes_per_class])
     iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
     score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
     batch_size = scores.shape[0]
@@ -122,7 +125,9 @@ def _multiclass_nms_single(boxes: Tensor,

     Single batch nms could be optimized.
     """
-    max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        max_output_boxes_per_class = torch.LongTensor(
+            [max_output_boxes_per_class])
     iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
     score_threshold = torch.tensor([score_threshold], dtype=torch.float32)

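Both `_multiclass_nms` and `_multiclass_nms_single` now wrap `max_output_boxes_per_class` in a `LongTensor` only on torch < 1.13; on newer versions the raw value is passed through, presumably because `ONNXNMSop.symbolic` builds the `Constant` node itself after the nms change further down. A small sketch of the same `packaging`-based gate (hypothetical helper name):

```python
import torch
from packaging import version


def wrap_for_legacy_torch(max_output_boxes_per_class: int):
    """Only materialize the value as a LongTensor on torch < 1.13."""
    if version.parse(torch.__version__) < version.parse('1.13.0'):
        return torch.LongTensor([max_output_boxes_per_class])
    return max_output_boxes_per_class
```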
@@ -269,8 +269,8 @@ class SingleRoIExtractorOpenVINO(Function):

     @staticmethod
     def symbolic(g, output_size, featmap_strides, sample_num, rois, *feats):
-        from torch.onnx.symbolic_helper import _slice_helper
-        rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+        from torch.onnx.symbolic_opset10 import _slice
+        rois = _slice(g, rois, axes=[1], starts=[1], ends=[5])
         domain = 'org.openvinotoolkit'
         op_name = 'ExperimentalDetectronROIFeatureExtractor'
         roi_feats = g.op(
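The OpenVINO symbolic switches from `symbolic_helper._slice_helper`, which appears to have been dropped in torch 1.13, to `_slice` from `torch.onnx.symbolic_opset10` with the same axes/starts/ends arguments. A hedged sketch of the call, assuming that opset10 signature (the wrapper name is made up):

```python
def strip_batch_index(g, rois):
    # Keep columns 1:5 of the (num_rois, 5) rois tensor; column 0 holds the
    # batch index. Mirrors the call in SingleRoIExtractorOpenVINO.symbolic.
    from torch.onnx.symbolic_opset10 import _slice
    return _slice(g, rois, axes=[1], starts=[1], ends=[5])
```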
@@ -1,8 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import context_block  # noqa: F401,F403
-from . import conv2d_adaptive_padding  # noqa: F401,F403
 from . import hsigmoid  # noqa: F401,F403
 from . import hswish  # noqa: F401,F403
 from .transformer import MultiHeadAttentionop

-__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop']
+__all__ = ['MultiHeadAttentionop']
@@ -1,86 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-
-import torch
-import torch.nn.functional as F
-
-from mmdeploy.core import FUNCTION_REWRITER
-from mmdeploy.utils import Backend, is_dynamic_batch, is_dynamic_shape
-
-
-def compute_padding(input_size, kernel_size, stride, dilation):
-    """Compute padding."""
-
-    input_h, input_w = input_size
-    kernel_h, kernel_w = kernel_size
-    stride_h, stride_w = stride
-    dilation_h, dilation_w = dilation
-    output_h = math.ceil(input_h / stride_h)
-    output_w = math.ceil(input_w / stride_w)
-    pad_h = max(
-        (output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h,
-        0)
-    pad_w = max(
-        (output_w - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - input_w,
-        0)
-    if pad_w > 0 or pad_h > 0:
-        padded = [
-            pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
-        ]
-    else:
-        padded = None
-    return padded
-
-
-class AdaptivePadOp(torch.autograd.Function):
-    """Dummy adaptive pad op."""
-
-    @staticmethod
-    def forward(ctx, x, padded):
-        if padded is not None:
-            x = F.pad(x, padded)
-        return x
-
-    @staticmethod
-    def symbolic(g, x, padded):
-        if padded is None:
-            return g.op('Identity', x)
-        padded = g.op(
-            'Constant', value_t=torch.tensor(padded, dtype=torch.int64))
-        constant_value = g.op(
-            'Constant', value_t=torch.tensor(0, dtype=torch.int64))
-        return g.op(
-            'Pad', x, padded, constant_value, mode_s='constant', outputs=1)
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmcv.cnn.bricks.conv2d_adaptive_padding. \
-Conv2dAdaptivePadding.forward',
-    backend=Backend.TENSORRT.value)
-def conv2d_adaptive_padding__forward__tensorrt(ctx, self, x):
-    """Rewrite `forward` of Conv2dAdaptivePadding used in EfficientNet for
-    TensorRT backend. Main changes of this rewritten function is to separate
-    the computation of padding and encapsulate it into another
-    `torch.autograd.Function` so that the adaptive padding could be parsed as
-    `Pad` ops in ONNX with the padding information computed in advance (Only
-    for static shape configuration).
-
-    Args:
-        x (Tensor): Input tensor of Conv2dAdaptivePadding ops
-    Returns:
-        Tensor: forward result of 2D convolution after padding
-    """
-
-    deploy_cfg = ctx.cfg
-    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
-    if (not is_dynamic_flag) or is_dynamic_batch(deploy_cfg):
-        padded = compute_padding(x.shape[2:], self.weight.shape[2:],
-                                 self.stride, self.dilation)
-        if padded is not None:
-            padded = [int(_) for _ in padded]
-        x = AdaptivePadOp.apply(x, padded)
-        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
-                        self.dilation, self.groups)
-    else:
-        x = ctx.origin_func(x)
-        return x
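For context, the deleted `compute_padding` implements TensorFlow-style "same" padding: the output size is fixed to `ceil(input / stride)` and just enough padding is added to realize it, split as evenly as possible between the two sides. A quick numeric check of that arithmetic (not part of the diff):

```python
import math

# Example: 224x224 input, 3x3 kernel, stride 2, dilation 1.
input_h, kernel_h, stride_h, dilation_h = 224, 3, 2, 1
output_h = math.ceil(input_h / stride_h)  # 112
pad_h = max(
    (output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h,
    0)  # (112 - 1) * 2 + 2 + 1 - 224 = 1
# The deleted helper split each total pad as [pad // 2, pad - pad // 2],
# i.e. [0, 1] here, matching F.pad's (left, right, top, bottom) ordering
# once the width and height pairs are combined.
padded_h = [pad_h // 2, pad_h - pad_h // 2]  # [0, 1]
```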
@@ -3,8 +3,6 @@ import torch
 from torch import Tensor
 from torch.onnx import symbolic_helper as sym_help
-
-from mmdeploy.core import SYMBOLIC_REWRITER


 class ONNXNMSop(torch.autograd.Function):
     """Create onnx::NonMaxSuppression op.
@@ -78,57 +76,24 @@ class ONNXNMSop(torch.autograd.Function):
         Returns:
             NonMaxSuppression op for onnx.
         """
-        return g.op(
-            'NonMaxSuppression',
-            boxes,
-            scores,
-            max_output_boxes_per_class,
-            iou_threshold,
-            score_threshold,
-            outputs=1)
-
-
-@SYMBOLIC_REWRITER.register_symbolic(
-    'mmdeploy.mmcv.ops.ONNXNMSop', backend='default')
-def nms_dynamic(ctx, g, boxes: Tensor, scores: Tensor,
-                max_output_boxes_per_class: int, iou_threshold: float,
-                score_threshold: float):
-    """Rewrite symbolic function for default backend.
-
-    Support max_output_boxes_per_class, iou_threshold, score_threshold of
-    constant Tensor, which is aligned with ONNX's nms op.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        g (Graph): The traced onnx graph.
-        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
-        scores (Tensor): The detection scores of shape
-            [N, num_boxes, num_classes].
-        max_output_boxes_per_class (int): Maximum number of output
-            boxes per class of nms.
-        iou_threshold (float): IOU threshold of nms.
-        score_threshold (float): score threshold of nms.
-
-    Returns:
-        NonMaxSuppression op for onnx.
-    """
-
-    if not sym_help._is_value(max_output_boxes_per_class):
-        max_output_boxes_per_class = g.op(
-            'Constant',
-            value_t=torch.tensor(max_output_boxes_per_class, dtype=torch.long))
-
-    if not sym_help._is_value(iou_threshold):
-        iou_threshold = g.op(
-            'Constant',
-            value_t=torch.tensor([iou_threshold], dtype=torch.float))
-
-    if not sym_help._is_value(score_threshold):
-        score_threshold = g.op(
-            'Constant',
-            value_t=torch.tensor([score_threshold], dtype=torch.float))
-    return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class,
-                iou_threshold, score_threshold)
+        if not sym_help._is_value(max_output_boxes_per_class):
+            max_output_boxes_per_class = g.op(
+                'Constant',
+                value_t=torch.tensor(
+                    max_output_boxes_per_class, dtype=torch.long))
+
+        if not sym_help._is_value(iou_threshold):
+            iou_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([iou_threshold], dtype=torch.float))
+
+        if not sym_help._is_value(score_threshold):
+            score_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([score_threshold], dtype=torch.float))
+        return g.op('NonMaxSuppression', boxes, scores,
+                    max_output_boxes_per_class, iou_threshold, score_threshold)


 class TRTBatchedNMSop(torch.autograd.Function):
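The net effect of this hunk: the scalar-to-`Constant` promotion that used to live in the standalone `nms_dynamic` symbolic rewriter now happens inside `ONNXNMSop.symbolic` itself, and the separate rewriter plus the old `outputs=1` return are dropped. `sym_help._is_value` distinguishes graph `Value`s from plain Python scalars; scalars get wrapped before `NonMaxSuppression` is emitted. A minimal illustration of that promotion pattern (standalone helper, not mmdeploy code):

```python
import torch
from torch.onnx import symbolic_helper as sym_help


def as_constant(g, value, dtype):
    """Promote a Python scalar to an ONNX Constant node unless it is already
    a graph Value (i.e. produced by another node)."""
    if sym_help._is_value(value):
        return value
    return g.op('Constant', value_t=torch.tensor(value, dtype=dtype))
```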
@@ -56,17 +56,17 @@ def roi_align_default(ctx, g, input: Tensor, rois: Tensor,
             aligned_i=aligned)
     else:
         from torch.onnx.symbolic_opset9 import _cast_Long
-        from torch.onnx.symbolic_opset11 import add, select, squeeze
+        from torch.onnx.symbolic_opset11 import add, select
         batch_indices = _cast_Long(
             g,
-            squeeze(
-                g,
+            g.op(
+                'Squeeze',
                 select(
                     g, rois, 1,
                     g.op(
                         'Constant',
-                        value_t=torch.tensor([0], dtype=torch.long))), 1),
-            False)
+                        value_t=torch.tensor([0], dtype=torch.long))),
+                axes_i=[1]), False)
         rois = select(
             g, rois, 1,
             g.op(
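Same theme as the grid-priors symbolic: the opset11 `squeeze` helper is dropped from the import and the node is emitted directly with `g.op('Squeeze', ..., axes_i=[1])`, again the attribute form for opset < 13. A sketch of the resulting batch-index extraction, reassembled from the lines above (the function name is illustrative):

```python
import torch


def batch_indices_from_rois(g, rois):
    """Sketch of the new batch-index extraction in roi_align_default.

    Column 0 of the (num_rois, 5) rois tensor holds the batch index; it is
    gathered, squeezed with a direct Squeeze node (attribute form, opset < 13)
    and cast to int64, replacing the removed opset11 `squeeze` helper.
    """
    from torch.onnx.symbolic_opset9 import _cast_Long
    from torch.onnx.symbolic_opset11 import select
    column0 = select(
        g, rois, 1,
        g.op('Constant', value_t=torch.tensor([0], dtype=torch.long)))
    return _cast_Long(g, g.op('Squeeze', column0, axes_i=[1]), False)
```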
@@ -34,33 +34,6 @@ def test_multiheadattention_ncnn():
         model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)


-def test_conv2d_adaptive_padding_tensorrt():
-    check_backend(Backend.TENSORRT)
-    from mmcv.cnn.bricks.conv2d_adaptive_padding import Conv2dAdaptivePadding
-    in_channels, out_channels = 3, 64
-    kernel_sz = 3
-    model = Conv2dAdaptivePadding(in_channels, out_channels, kernel_sz)
-    dummy_input = torch.rand(1, 3, 256, 256)
-
-    deploy_cfg = mmcv.Config(
-        dict(
-            onnx_config=dict(input_shape=None),
-            backend_config=dict(type=Backend.TENSORRT.value),
-        ))
-    model_outputs = model(dummy_input)
-    rewrite_inputs = dict(x=dummy_input)
-    rewrite_outputs, is_backend_output = get_rewrite_outputs(
-        wrapped_model=model,
-        model_inputs=rewrite_inputs,
-        deploy_cfg=deploy_cfg,
-        run_with_backend=True)
-    if is_backend_output is None:
-        assert rewrite_outputs is not None
-    else:
-        assert torch.allclose(
-            model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)
-
-
 def test_context_block_ncnn():
     check_backend(Backend.NCNN)
     from mmcv.cnn.bricks.context_block import ContextBlock