better mmdet support

This commit is contained in:
grimoire 2021-06-18 15:16:21 +08:00
parent 2317ee659a
commit 2bd96243e2
13 changed files with 78 additions and 113 deletions

View File

@ -1,2 +1,2 @@
[settings]
known_third_party = mmcv,numpy,setuptools,torch
known_third_party = mmcv,mmdet,numpy,setuptools,torch

View File

@ -8,7 +8,6 @@ pytorch2onnx = dict(
0: 'batch'
},
'output': {
0: 'batch'
0: 'batch'
}
}
)
})

View File

@ -3,22 +3,39 @@ import torch
from mmdeploy.utils import FUNCTION_REWRITERS
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.topk', backend='default')
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.Tensor.topk', backend='default')
def rewrite_topk_default(rewriter, input, k, dim=None, largest=True, sorted=True):
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.topk', backend='default')
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.Tensor.topk', backend='default')
def rewrite_topk_default(rewriter,
input,
k,
dim=None,
largest=True,
sorted=True):
if dim is None:
dim = int(input.ndim - 1)
size = input.shape[dim]
if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=input.device, dtype=torch.long)
# Always keep topk op for dynamic input
if isinstance(size, torch.Tensor):
size = size.to(input.device)
k = torch.where(k < size, k, size)
return rewriter.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
return rewriter.origin_func(
input, k, dim=dim, largest=largest, sorted=sorted)
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.topk', backend='tensorrt')
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.Tensor.topk', backend='tensorrt')
def rewrite_topk_tensorrt(rewriter, input, k, dim=None, largest=True, sorted=True):
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.topk', backend='tensorrt')
@FUNCTION_REWRITERS.register_rewriter(
func_name='torch.Tensor.topk', backend='tensorrt')
def rewrite_topk_tensorrt(rewriter,
input,
k,
dim=None,
largest=True,
sorted=True):
if dim is None:
dim = int(input.ndim - 1)
size = input.shape[dim]
@ -26,4 +43,5 @@ def rewrite_topk_tensorrt(rewriter, input, k, dim=None, largest=True, sorted=Tru
k = size
if not isinstance(k, int):
k = int(k)
return rewriter.origin_func(input, k, dim=dim, largest=largest, sorted=sorted)
return rewriter.origin_func(
input, k, dim=dim, largest=largest, sorted=sorted)

View File

@ -1,5 +1,6 @@
import torch
from torch.onnx import symbolic_helper as sym_help
from mmdeploy.utils import SYMBOLICS_REGISTER
@ -27,14 +28,11 @@ class DummyONNXNMSop(torch.autograd.Function):
score_threshold,
outputs=1)
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.mmcv.ops.DummyONNXNMSop', backend='default')
def nms_default(symbolic_wrapper,
g,
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold):
@SYMBOLICS_REGISTER.register_symbolic(
'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='default')
def nms_default(symbolic_wrapper, g, boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold):
if not sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = g.op(
'Constant',
@ -49,18 +47,17 @@ def nms_default(symbolic_wrapper,
score_threshold = g.op(
'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float))
return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold)
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.mmcv.ops.DummyONNXNMSop', backend='tensorrt')
def nms_tensorrt(symbolic_wrapper,
g,
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold):
@SYMBOLICS_REGISTER.register_symbolic(
'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='tensorrt')
def nms_tensorrt(symbolic_wrapper, g, boxes, scores,
max_output_boxes_per_class, iou_threshold, score_threshold):
if sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = sym_help._maybe_get_const(max_output_boxes_per_class, 'i')
max_output_boxes_per_class = sym_help._maybe_get_const(
max_output_boxes_per_class, 'i')
if sym_help._is_value(iou_threshold):
iou_threshold = sym_help._maybe_get_const(iou_threshold, 'f')
@ -68,9 +65,10 @@ def nms_tensorrt(symbolic_wrapper,
if sym_help._is_value(score_threshold):
score_threshold = sym_help._maybe_get_const(score_threshold, 'f')
return g.op('NonMaxSuppression',
boxes,
scores,
max_output_boxes_per_class_i=max_output_boxes_per_class,
iou_threshold_f=iou_threshold,
score_threshold_f=score_threshold)
return g.op(
'NonMaxSuppression',
boxes,
scores,
max_output_boxes_per_class_i=max_output_boxes_per_class,
iou_threshold_f=iou_threshold,
score_threshold_f=score_threshold)

View File

@ -1,11 +1,10 @@
from .onnx_helper import (add_dummy_nms_for_onnx, dynamic_clip_for_onnx,
get_k_for_topk)
from .onnx_helper import add_dummy_nms_for_onnx, dynamic_clip_for_onnx
from .pytorch2onnx import (build_model_from_cfg,
generate_inputs_and_wrap_model,
preprocess_example_input)
__all__ = [
'build_model_from_cfg', 'generate_inputs_and_wrap_model',
'preprocess_example_input', 'get_k_for_topk', 'add_dummy_nms_for_onnx',
'preprocess_example_input', 'add_dummy_nms_for_onnx',
'dynamic_clip_for_onnx'
]

View File

@ -3,7 +3,6 @@ import warnings
import numpy as np
import torch
from mmdet.core import bbox2result
from mmdet.models import BaseDetector

View File

@ -1,5 +1,3 @@
import os
import torch
from mmdeploy.mmcv.ops import DummyONNXNMSop
@ -44,42 +42,6 @@ def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
return x1, y1, x2, y2
def get_k_for_topk(k, size):
"""Get k of TopK for onnx exporting.
The K of TopK in TensorRT should not be a Tensor, while in ONNX Runtime
it could be a Tensor.Due to dynamic shape feature, we have to decide
whether to do TopK and what K it should be while exporting to ONNX.
If returned K is less than zero, it means we do not have to do
TopK operation.
Args:
k (int or Tensor): The set k value for nms from config file.
size (Tensor or torch.Size): The number of elements of \
TopK's input tensor
Returns:
tuple: (int or Tensor): The final K for TopK.
"""
ret_k = -1
if k <= 0 or size <= 0:
return ret_k
if torch.onnx.is_in_onnx_export():
is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
if is_trt_backend:
# TensorRT does not support dynamic K with TopK op
if 0 < k < size:
ret_k = k
else:
# Always keep topk op for dynamic input in onnx for ONNX Runtime
ret_k = torch.where(k < size, k, size)
elif k < size:
ret_k = k
else:
# ret_k is -1
pass
return ret_k
def add_dummy_nms_for_onnx(boxes,
scores,
max_output_boxes_per_class=1000,
@ -122,12 +84,9 @@ def add_dummy_nms_for_onnx(boxes,
batch_size = scores.shape[0]
num_class = scores.shape[2]
nms_pre = torch.tensor(pre_top_k, device=scores.device, dtype=torch.long)
nms_pre = get_k_for_topk(nms_pre, boxes.shape[1])
if nms_pre > 0:
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(nms_pre)
_, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds).long()
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
@ -178,13 +137,10 @@ def add_dummy_nms_for_onnx(boxes,
boxes = boxes.reshape(batch_size, -1, 4)
labels = labels.reshape(batch_size, -1)
nms_after = torch.tensor(
after_top_k, device=scores.device, dtype=torch.long)
nms_after = get_k_for_topk(nms_after, num_box * num_class)
if nms_after > 0:
_, topk_inds = scores.topk(nms_after)
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)
if after_top_k > 0:
_, topk_inds = scores.topk(after_top_k)
batch_inds = torch.arange(
batch_size, device=scores.device).view(-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds, :].reshape(
@ -197,6 +153,3 @@ def add_dummy_nms_for_onnx(boxes,
scores = scores.unsqueeze(2)
dets = torch.cat([boxes, scores], dim=2)
return dets, labels

View File

@ -1,4 +1,4 @@
from .dense_heads import * # noqa: F401,F403
from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403

View File

@ -1,12 +1,16 @@
import torch
import torch.nn as nn
from mmdeploy.utils import MODULE_REWRITERS
from mmdeploy.mmdet.core.export import add_dummy_nms_for_onnx
from mmdeploy.utils import MODULE_REWRITERS
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.AnchorHead')
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaHead')
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.AnchorHead'
)
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaHead'
)
class AnchorHead(nn.Module):
def __init__(self, module, cfg, **kwargs):
super(AnchorHead, self).__init__()
self.module = module
@ -96,14 +100,12 @@ class AnchorHead(nn.Module):
if not self.use_sigmoid_cls:
num_classes = batch_mlvl_scores.shape[2] - 1
batch_mlvl_scores = batch_mlvl_scores[..., :num_classes]
max_output_boxes_per_class = cfg.nms.get(
'max_output_boxes_per_class', 200)
max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class',
200)
iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1)
return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold, score_threshold,
nms_pre, cfg.max_per_img)
iou_threshold, score_threshold, nms_pre,
cfg.max_per_img)

View File

@ -5,16 +5,13 @@ from mmdeploy.utils import MODULE_REWRITERS
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaNet')
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.SingleStageDetector')
@MODULE_REWRITERS.register_rewrite_module(
module_type='mmdet.models.SingleStageDetector')
class SingleStageDetector(nn.Module):
def __init__(self, module, cfg, **kwargs):
super(SingleStageDetector, self).__init__()
self.module = module
self.backbone = module.backbone
self.with_neck = module.with_neck
if module.neck is not None:
self.neck = module.neck
self.bbox_head = module.bbox_head
def forward(self, data, **kwargs):