better mmdet support

This commit is contained in:
grimoire 2021-06-18 15:16:21 +08:00
parent 2317ee659a
commit 2bd96243e2
13 changed files with 78 additions and 113 deletions

View File

@ -1,2 +1,2 @@
[settings] [settings]
known_third_party = mmcv,numpy,setuptools,torch known_third_party = mmcv,mmdet,numpy,setuptools,torch

View File

@ -8,7 +8,6 @@ pytorch2onnx = dict(
0: 'batch' 0: 'batch'
}, },
'output': { 'output': {
0: 'batch' 0: 'batch'
} }
} })
)

View File

@ -3,22 +3,39 @@ import torch
from mmdeploy.utils import FUNCTION_REWRITERS from mmdeploy.utils import FUNCTION_REWRITERS
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.topk', backend='default') @FUNCTION_REWRITERS.register_rewriter(
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.Tensor.topk', backend='default') func_name='torch.topk', backend='default')
def rewrite_topk_default(rewriter, input, k, dim=None, largest=True, sorted=True): @FUNCTION_REWRITERS.register_rewriter(
func_name='torch.Tensor.topk', backend='default')
def rewrite_topk_default(rewriter,
input,
k,
dim=None,
largest=True,
sorted=True):
if dim is None: if dim is None:
dim = int(input.ndim - 1) dim = int(input.ndim - 1)
size = input.shape[dim] size = input.shape[dim]
if not isinstance(k, torch.Tensor): if not isinstance(k, torch.Tensor):
k = torch.tensor(k, device=input.device, dtype=torch.long) k = torch.tensor(k, device=input.device, dtype=torch.long)
# Always keep topk op for dynamic input # Always keep topk op for dynamic input
if isinstance(size, torch.Tensor):
size = size.to(input.device)
k = torch.where(k < size, k, size) k = torch.where(k < size, k, size)
return rewriter.origin_func(input, k, dim=dim, largest=largest, sorted=sorted) return rewriter.origin_func(
input, k, dim=dim, largest=largest, sorted=sorted)
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.topk', backend='tensorrt') @FUNCTION_REWRITERS.register_rewriter(
@FUNCTION_REWRITERS.register_rewriter(func_name='torch.Tensor.topk', backend='tensorrt') func_name='torch.topk', backend='tensorrt')
def rewrite_topk_tensorrt(rewriter, input, k, dim=None, largest=True, sorted=True): @FUNCTION_REWRITERS.register_rewriter(
func_name='torch.Tensor.topk', backend='tensorrt')
def rewrite_topk_tensorrt(rewriter,
input,
k,
dim=None,
largest=True,
sorted=True):
if dim is None: if dim is None:
dim = int(input.ndim - 1) dim = int(input.ndim - 1)
size = input.shape[dim] size = input.shape[dim]
@ -26,4 +43,5 @@ def rewrite_topk_tensorrt(rewriter, input, k, dim=None, largest=True, sorted=Tru
k = size k = size
if not isinstance(k, int): if not isinstance(k, int):
k = int(k) k = int(k)
return rewriter.origin_func(input, k, dim=dim, largest=largest, sorted=sorted) return rewriter.origin_func(
input, k, dim=dim, largest=largest, sorted=sorted)

View File

@ -1,5 +1,6 @@
import torch import torch
from torch.onnx import symbolic_helper as sym_help from torch.onnx import symbolic_helper as sym_help
from mmdeploy.utils import SYMBOLICS_REGISTER from mmdeploy.utils import SYMBOLICS_REGISTER
@ -27,14 +28,11 @@ class DummyONNXNMSop(torch.autograd.Function):
score_threshold, score_threshold,
outputs=1) outputs=1)
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.mmcv.ops.DummyONNXNMSop', backend='default')
def nms_default(symbolic_wrapper, @SYMBOLICS_REGISTER.register_symbolic(
g, 'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='default')
boxes, def nms_default(symbolic_wrapper, g, boxes, scores, max_output_boxes_per_class,
scores, iou_threshold, score_threshold):
max_output_boxes_per_class,
iou_threshold,
score_threshold):
if not sym_help._is_value(max_output_boxes_per_class): if not sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = g.op( max_output_boxes_per_class = g.op(
'Constant', 'Constant',
@ -49,18 +47,17 @@ def nms_default(symbolic_wrapper,
score_threshold = g.op( score_threshold = g.op(
'Constant', 'Constant',
value_t=torch.tensor([score_threshold], dtype=torch.float)) value_t=torch.tensor([score_threshold], dtype=torch.float))
return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) return g.op('NonMaxSuppression', boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold)
@SYMBOLICS_REGISTER.register_symbolic('mmdeploy.mmcv.ops.DummyONNXNMSop', backend='tensorrt')
def nms_tensorrt(symbolic_wrapper, @SYMBOLICS_REGISTER.register_symbolic(
g, 'mmdeploy.mmcv.ops.DummyONNXNMSop', backend='tensorrt')
boxes, def nms_tensorrt(symbolic_wrapper, g, boxes, scores,
scores, max_output_boxes_per_class, iou_threshold, score_threshold):
max_output_boxes_per_class,
iou_threshold,
score_threshold):
if sym_help._is_value(max_output_boxes_per_class): if sym_help._is_value(max_output_boxes_per_class):
max_output_boxes_per_class = sym_help._maybe_get_const(max_output_boxes_per_class, 'i') max_output_boxes_per_class = sym_help._maybe_get_const(
max_output_boxes_per_class, 'i')
if sym_help._is_value(iou_threshold): if sym_help._is_value(iou_threshold):
iou_threshold = sym_help._maybe_get_const(iou_threshold, 'f') iou_threshold = sym_help._maybe_get_const(iou_threshold, 'f')
@ -68,9 +65,10 @@ def nms_tensorrt(symbolic_wrapper,
if sym_help._is_value(score_threshold): if sym_help._is_value(score_threshold):
score_threshold = sym_help._maybe_get_const(score_threshold, 'f') score_threshold = sym_help._maybe_get_const(score_threshold, 'f')
return g.op('NonMaxSuppression', return g.op(
boxes, 'NonMaxSuppression',
scores, boxes,
max_output_boxes_per_class_i=max_output_boxes_per_class, scores,
iou_threshold_f=iou_threshold, max_output_boxes_per_class_i=max_output_boxes_per_class,
score_threshold_f=score_threshold) iou_threshold_f=iou_threshold,
score_threshold_f=score_threshold)

View File

@ -1,11 +1,10 @@
from .onnx_helper import (add_dummy_nms_for_onnx, dynamic_clip_for_onnx, from .onnx_helper import add_dummy_nms_for_onnx, dynamic_clip_for_onnx
get_k_for_topk)
from .pytorch2onnx import (build_model_from_cfg, from .pytorch2onnx import (build_model_from_cfg,
generate_inputs_and_wrap_model, generate_inputs_and_wrap_model,
preprocess_example_input) preprocess_example_input)
__all__ = [ __all__ = [
'build_model_from_cfg', 'generate_inputs_and_wrap_model', 'build_model_from_cfg', 'generate_inputs_and_wrap_model',
'preprocess_example_input', 'get_k_for_topk', 'add_dummy_nms_for_onnx', 'preprocess_example_input', 'add_dummy_nms_for_onnx',
'dynamic_clip_for_onnx' 'dynamic_clip_for_onnx'
] ]

View File

@ -3,7 +3,6 @@ import warnings
import numpy as np import numpy as np
import torch import torch
from mmdet.core import bbox2result from mmdet.core import bbox2result
from mmdet.models import BaseDetector from mmdet.models import BaseDetector

View File

@ -1,5 +1,3 @@
import os
import torch import torch
from mmdeploy.mmcv.ops import DummyONNXNMSop from mmdeploy.mmcv.ops import DummyONNXNMSop
@ -44,42 +42,6 @@ def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
return x1, y1, x2, y2 return x1, y1, x2, y2
def get_k_for_topk(k, size):
    """Get k of TopK for onnx exporting.

    The K of TopK in TensorRT should not be a Tensor, while in ONNX Runtime
    it could be a Tensor. Due to the dynamic shape feature, we have to decide
    whether to do TopK and what K it should be while exporting to ONNX.
    If the returned K is less than zero, it means we do not have to do
    the TopK operation.

    Args:
        k (int or Tensor): The set k value for nms from config file.
        size (Tensor or torch.Size): The number of elements of
            TopK's input tensor. It is compared directly against 0 and
            ``k``, so it must be a scalar-like value.

    Returns:
        int or Tensor: The final K for TopK, or -1 when TopK should be
        skipped.
    """
    skip_topk = -1

    # A non-positive k or an empty input means there is nothing to select.
    if k <= 0 or size <= 0:
        return skip_topk

    if not torch.onnx.is_in_onnx_export():
        # Plain PyTorch execution: TopK is only useful when it actually
        # reduces the number of candidates.
        return k if k < size else skip_topk

    if os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT':
        # TensorRT does not support dynamic K with TopK op, so K is either
        # the configured value or TopK is dropped entirely. ``k > 0`` is
        # already guaranteed by the guard above.
        return k if k < size else skip_topk

    # Always keep topk op for dynamic input in onnx for ONNX Runtime.
    # NOTE(review): torch.where needs a tensor condition, so this path
    # presumably expects ``k`` and/or ``size`` to be Tensors - confirm
    # against callers.
    return torch.where(k < size, k, size)
def add_dummy_nms_for_onnx(boxes, def add_dummy_nms_for_onnx(boxes,
scores, scores,
max_output_boxes_per_class=1000, max_output_boxes_per_class=1000,
@ -122,12 +84,9 @@ def add_dummy_nms_for_onnx(boxes,
batch_size = scores.shape[0] batch_size = scores.shape[0]
num_class = scores.shape[2] num_class = scores.shape[2]
nms_pre = torch.tensor(pre_top_k, device=scores.device, dtype=torch.long) if pre_top_k > 0:
nms_pre = get_k_for_topk(nms_pre, boxes.shape[1])
if nms_pre > 0:
max_scores, _ = scores.max(-1) max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(nms_pre) _, topk_inds = max_scores.topk(pre_top_k)
batch_inds = torch.arange(batch_size).view( batch_inds = torch.arange(batch_size).view(
-1, 1).expand_as(topk_inds).long() -1, 1).expand_as(topk_inds).long()
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
@ -178,13 +137,10 @@ def add_dummy_nms_for_onnx(boxes,
boxes = boxes.reshape(batch_size, -1, 4) boxes = boxes.reshape(batch_size, -1, 4)
labels = labels.reshape(batch_size, -1) labels = labels.reshape(batch_size, -1)
nms_after = torch.tensor( if after_top_k > 0:
after_top_k, device=scores.device, dtype=torch.long) _, topk_inds = scores.topk(after_top_k)
nms_after = get_k_for_topk(nms_after, num_box * num_class) batch_inds = torch.arange(
batch_size, device=scores.device).view(-1, 1).expand_as(topk_inds)
if nms_after > 0:
_, topk_inds = scores.topk(nms_after)
batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
transformed_inds = scores.shape[1] * batch_inds + topk_inds transformed_inds = scores.shape[1] * batch_inds + topk_inds
scores = scores.reshape(-1, 1)[transformed_inds, :].reshape( scores = scores.reshape(-1, 1)[transformed_inds, :].reshape(
@ -197,6 +153,3 @@ def add_dummy_nms_for_onnx(boxes,
scores = scores.unsqueeze(2) scores = scores.unsqueeze(2)
dets = torch.cat([boxes, scores], dim=2) dets = torch.cat([boxes, scores], dim=2)
return dets, labels return dets, labels

View File

@ -1,4 +1,4 @@
from .dense_heads import * # noqa: F401,F403 from .dense_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403
from .necks import * # noqa: F401,F403 from .necks import * # noqa: F401,F403
from .roi_heads import * # noqa: F401,F403

View File

@ -1,12 +1,16 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmdeploy.utils import MODULE_REWRITERS
from mmdeploy.mmdet.core.export import add_dummy_nms_for_onnx from mmdeploy.mmdet.core.export import add_dummy_nms_for_onnx
from mmdeploy.utils import MODULE_REWRITERS
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.AnchorHead')
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaHead') @MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.AnchorHead'
)
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaHead'
)
class AnchorHead(nn.Module): class AnchorHead(nn.Module):
def __init__(self, module, cfg, **kwargs): def __init__(self, module, cfg, **kwargs):
super(AnchorHead, self).__init__() super(AnchorHead, self).__init__()
self.module = module self.module = module
@ -96,14 +100,12 @@ class AnchorHead(nn.Module):
if not self.use_sigmoid_cls: if not self.use_sigmoid_cls:
num_classes = batch_mlvl_scores.shape[2] - 1 num_classes = batch_mlvl_scores.shape[2] - 1
batch_mlvl_scores = batch_mlvl_scores[..., :num_classes] batch_mlvl_scores = batch_mlvl_scores[..., :num_classes]
max_output_boxes_per_class = cfg.nms.get( max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class',
'max_output_boxes_per_class', 200) 200)
iou_threshold = cfg.nms.get('iou_threshold', 0.5) iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1) nms_pre = cfg.get('deploy_nms_pre', -1)
return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores, return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class, max_output_boxes_per_class,
iou_threshold, score_threshold, iou_threshold, score_threshold, nms_pre,
nms_pre, cfg.max_per_img) cfg.max_per_img)

View File

@ -5,16 +5,13 @@ from mmdeploy.utils import MODULE_REWRITERS
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaNet') @MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.RetinaNet')
@MODULE_REWRITERS.register_rewrite_module(module_type='mmdet.models.SingleStageDetector') @MODULE_REWRITERS.register_rewrite_module(
module_type='mmdet.models.SingleStageDetector')
class SingleStageDetector(nn.Module): class SingleStageDetector(nn.Module):
def __init__(self, module, cfg, **kwargs): def __init__(self, module, cfg, **kwargs):
super(SingleStageDetector, self).__init__() super(SingleStageDetector, self).__init__()
self.module = module self.module = module
self.backbone = module.backbone
self.with_neck = module.with_neck
if module.neck is not None:
self.neck = module.neck
self.bbox_head = module.bbox_head self.bbox_head = module.bbox_head
def forward(self, data, **kwargs): def forward(self, data, **kwargs):