Rewrite delta2bbox (#8)
* add delta2bbox rewriter * rename onnx2trt * add rewriter: anchor_generator_single_level_grid_priors * Revert "add rewriter: anchor_generator_single_level_grid_priors" This reverts commit ac7cf272942c4787bf143c0d67e414b0d2603b75. * update comments * remove clamp rewriter * remove unused funcpull/12/head
parent
66300c0c74
commit
dae6a8ccf9
|
@ -1,11 +1,9 @@
|
|||
import ctypes
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
def get_ops_path():
|
||||
"""Get TensorRT plugins library path."""
|
||||
"""Get ONNX Runtime plugins library path."""
|
||||
wildcard = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
|
@ -14,14 +12,3 @@ def get_ops_path():
|
|||
paths = glob.glob(wildcard)
|
||||
lib_path = paths[0] if len(paths) > 0 else ''
|
||||
return lib_path
|
||||
|
||||
|
||||
def load_tensorrt_plugin():
|
||||
"""load TensorRT plugins library."""
|
||||
lib_path = get_ops_path()
|
||||
if os.path.exists(lib_path):
|
||||
ctypes.CDLL(lib_path)
|
||||
return 0
|
||||
else:
|
||||
logging.warning('Can not load tensorrt custom ops.')
|
||||
return -1
|
||||
|
|
|
@ -55,7 +55,7 @@ def torch2onnx(img: Any,
|
|||
if ret_value is not None:
|
||||
ret_value.value = -1
|
||||
|
||||
# load deploy_cfg if needed
|
||||
# load deploy_cfg if necessary
|
||||
if isinstance(deploy_cfg, str):
|
||||
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
|
||||
if not isinstance(deploy_cfg, mmcv.Config):
|
||||
|
|
|
@ -14,14 +14,13 @@ def is_available():
|
|||
|
||||
if is_available():
|
||||
from .onnx2tensorrt import onnx2tensorrt
|
||||
from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt,
|
||||
save_trt_engine)
|
||||
from .tensorrt_utils import (TRTWrapper, load_trt_engine,
|
||||
create_trt_engine, save_trt_engine)
|
||||
|
||||
# load tensorrt plugin lib
|
||||
load_tensorrt_plugin()
|
||||
|
||||
__all__ = [
|
||||
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper',
|
||||
'TRTWrapper', 'is_tensorrt_plugin_loaded', 'preprocess_onnx',
|
||||
'onnx2tensorrt'
|
||||
'create_trt_engine', 'save_trt_engine', 'load_trt_engine',
|
||||
'TRTWrapper', 'TRTWrapper', 'onnx2tensorrt'
|
||||
]
|
||||
|
|
|
@ -7,7 +7,7 @@ import onnx
|
|||
import tensorrt as trt
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from .tensorrt_utils import onnx2trt, save_trt_engine
|
||||
from .tensorrt_utils import create_trt_engine, save_trt_engine
|
||||
|
||||
|
||||
def get_trt_loglevel():
|
||||
|
@ -34,7 +34,7 @@ def onnx2tensorrt(work_dir: str,
|
|||
ret_value: Optional[mp.Value] = None):
|
||||
ret_value.value = -1
|
||||
|
||||
# load deploy_cfg if needed
|
||||
# load deploy_cfg if necessary
|
||||
if isinstance(deploy_cfg, str):
|
||||
deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
|
||||
elif not isinstance(deploy_cfg, mmcv.Config):
|
||||
|
@ -56,7 +56,7 @@ def onnx2tensorrt(work_dir: str,
|
|||
device_id = 0
|
||||
if len(device) >= 6:
|
||||
device_id = int(device[5:])
|
||||
engine = onnx2trt(
|
||||
engine = create_trt_engine(
|
||||
onnx_model,
|
||||
opt_shape_dict=final_param['opt_shape_dict'],
|
||||
log_level=final_param.get('log_level', get_trt_loglevel()),
|
||||
|
|
|
@ -3,13 +3,13 @@ import tensorrt as trt
|
|||
import torch
|
||||
|
||||
|
||||
def onnx2trt(onnx_model,
|
||||
opt_shape_dict,
|
||||
log_level=trt.Logger.ERROR,
|
||||
fp16_mode=False,
|
||||
max_workspace_size=0,
|
||||
device_id=0):
|
||||
"""Convert onnx model to tensorrt engine.
|
||||
def create_trt_engine(onnx_model,
|
||||
opt_shape_dict,
|
||||
log_level=trt.Logger.ERROR,
|
||||
fp16_mode=False,
|
||||
max_workspace_size=0,
|
||||
device_id=0):
|
||||
"""Create a tensorrt engine from ONNX.
|
||||
|
||||
Arguments:
|
||||
onnx_model (str or onnx.ModelProto): the onnx model to convert from
|
||||
|
@ -24,7 +24,7 @@ def onnx2trt(onnx_model,
|
|||
tensorrt.ICudaEngine: the TensorRT engine created from onnx_model
|
||||
|
||||
Example:
|
||||
>>> engine = onnx2trt(
|
||||
>>> engine = create_trt_engine(
|
||||
>>> "onnx_model.onnx",
|
||||
>>> {'input': [[1, 3, 160, 160],
|
||||
>>> [1, 3, 320, 320],
|
||||
|
|
|
@ -13,10 +13,8 @@ class DummyONNXNMSop(torch.autograd.Function):
|
|||
@staticmethod
|
||||
def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold,
|
||||
score_threshold):
|
||||
|
||||
batch_size, num_class, num_box = scores.shape
|
||||
# turn off tracing to create a dummy output of nms
|
||||
# dummy indices of nms's output
|
||||
# create dummy indices of nms output
|
||||
num_fake_det = 2
|
||||
batch_inds = torch.randint(batch_size, (num_fake_det, 1))
|
||||
cls_inds = torch.randint(num_class, (num_fake_det, 1))
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .delta_xywh_bbox_coder import * # noqa: F401,F403
|
|
@ -0,0 +1,68 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox', # noqa
|
||||
backend='default')
|
||||
def delta2bbox(rewriter,
|
||||
rois,
|
||||
deltas,
|
||||
means=(0., 0., 0., 0.),
|
||||
stds=(1., 1., 1., 1.),
|
||||
max_shape=None,
|
||||
wh_ratio_clip=16 / 1000,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
means = deltas.new_tensor(means).view(1,
|
||||
-1).repeat(1,
|
||||
deltas.size(-1) // 4)
|
||||
stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
|
||||
denorm_deltas = deltas * stds + means
|
||||
dx = denorm_deltas[..., 0::4]
|
||||
dy = denorm_deltas[..., 1::4]
|
||||
dw = denorm_deltas[..., 2::4]
|
||||
dh = denorm_deltas[..., 3::4]
|
||||
|
||||
x1, y1 = rois[..., 0], rois[..., 1]
|
||||
x2, y2 = rois[..., 2], rois[..., 3]
|
||||
# Compute center of each roi
|
||||
px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
|
||||
py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
|
||||
# Compute width/height of each roi
|
||||
pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
|
||||
ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
|
||||
|
||||
dx_width = pw * dx
|
||||
dy_height = ph * dy
|
||||
|
||||
max_ratio = np.abs(np.log(wh_ratio_clip))
|
||||
if add_ctr_clamp:
|
||||
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
|
||||
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
|
||||
dw = torch.clamp(dw, max=max_ratio)
|
||||
dh = torch.clamp(dh, max=max_ratio)
|
||||
else:
|
||||
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
||||
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
||||
# Use exp(network energy) to enlarge/shrink each roi
|
||||
gw = pw * dw.exp()
|
||||
gh = ph * dh.exp()
|
||||
# Use network energy to shift the center of each roi
|
||||
gx = px + dx_width
|
||||
gy = py + dy_height
|
||||
# Convert center-xy/width/height to top-left, bottom-right
|
||||
x1 = gx - gw * 0.5
|
||||
y1 = gy - gh * 0.5
|
||||
x2 = gx + gw * 0.5
|
||||
y2 = gy + gh * 0.5
|
||||
|
||||
if clip_border and max_shape is not None:
|
||||
from mmdeploy.mmdet.core.export import clip_bboxes
|
||||
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)
|
||||
|
||||
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
|
||||
return bboxes
|
|
@ -1,3 +1,3 @@
|
|||
from .onnx_helper import add_dummy_nms_for_onnx, dynamic_clip_for_onnx
|
||||
from .onnx_helper import add_dummy_nms_for_onnx, clip_bboxes
|
||||
|
||||
__all__ = ['add_dummy_nms_for_onnx', 'dynamic_clip_for_onnx']
|
||||
__all__ = ['add_dummy_nms_for_onnx', 'clip_bboxes']
|
||||
|
|
|
@ -4,42 +4,45 @@ from mmdeploy.mmcv.ops import DummyONNXNMSop, TRTBatchedNMSop
|
|||
from mmdeploy.utils import FUNCTION_REWRITERS
|
||||
|
||||
|
||||
def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
|
||||
"""Clip boxes dynamically for onnx.
|
||||
def clip_bboxes(x1, y1, x2, y2, max_shape):
|
||||
"""Clip bboxes for onnx.
|
||||
|
||||
Since torch.clamp cannot have dynamic `min` and `max`, we scale the
|
||||
boxes by 1/max_shape and clamp in the range [0, 1].
|
||||
boxes by 1/max_shape and clamp in the range [0, 1] if necessary.
|
||||
|
||||
Args:
|
||||
x1 (Tensor): The x1 for bounding boxes.
|
||||
y1 (Tensor): The y1 for bounding boxes.
|
||||
x2 (Tensor): The x2 for bounding boxes.
|
||||
y2 (Tensor): The y2 for bounding boxes.
|
||||
max_shape (Tensor or torch.Size): The (H,W) of original image.
|
||||
max_shape (List or Tensor): The (H,W) of original image.
|
||||
Returns:
|
||||
tuple(Tensor): The clipped x1, y1, x2, y2.
|
||||
"""
|
||||
assert isinstance(
|
||||
max_shape,
|
||||
torch.Tensor), '`max_shape` should be tensor of (h,w) for onnx'
|
||||
assert len(max_shape) == 2, '`max_shape` should be [h, w]'
|
||||
if isinstance(max_shape, torch.Tensor):
|
||||
# scale by 1/max_shape
|
||||
x1 = x1 / max_shape[1]
|
||||
y1 = y1 / max_shape[0]
|
||||
x2 = x2 / max_shape[1]
|
||||
y2 = y2 / max_shape[0]
|
||||
|
||||
# scale by 1/max_shape
|
||||
x1 = x1 / max_shape[1]
|
||||
y1 = y1 / max_shape[0]
|
||||
x2 = x2 / max_shape[1]
|
||||
y2 = y2 / max_shape[0]
|
||||
# clamp [0, 1]
|
||||
x1 = torch.clamp(x1, 0, 1)
|
||||
y1 = torch.clamp(y1, 0, 1)
|
||||
x2 = torch.clamp(x2, 0, 1)
|
||||
y2 = torch.clamp(y2, 0, 1)
|
||||
|
||||
# clamp [0, 1]
|
||||
x1 = torch.clamp(x1, 0, 1)
|
||||
y1 = torch.clamp(y1, 0, 1)
|
||||
x2 = torch.clamp(x2, 0, 1)
|
||||
y2 = torch.clamp(y2, 0, 1)
|
||||
|
||||
# scale back
|
||||
x1 = x1 * max_shape[1]
|
||||
y1 = y1 * max_shape[0]
|
||||
x2 = x2 * max_shape[1]
|
||||
y2 = y2 * max_shape[0]
|
||||
# scale back
|
||||
x1 = x1 * max_shape[1]
|
||||
y1 = y1 * max_shape[0]
|
||||
x2 = x2 * max_shape[1]
|
||||
y2 = y2 * max_shape[0]
|
||||
else:
|
||||
x1 = torch.clamp(x1, 0, max_shape[1])
|
||||
y1 = torch.clamp(y1, 0, max_shape[0])
|
||||
x2 = torch.clamp(x2, 0, max_shape[1])
|
||||
y2 = torch.clamp(y2, 0, max_shape[0])
|
||||
return x1, y1, x2, y2
|
||||
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ def anchor_head_get_bboxes(rewriter,
|
|||
**kwargs):
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = rewriter.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
|
||||
device = cls_scores[0].device
|
||||
|
@ -49,7 +50,7 @@ def anchor_head_get_bboxes(rewriter,
|
|||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
|
||||
# use static anchor if input shape is static
|
||||
if not is_dynamic_shape(deploy_cfg):
|
||||
if not is_dynamic_flag:
|
||||
anchors = anchors.data
|
||||
|
||||
anchors = anchors.expand_as(bbox_pred)
|
||||
|
@ -78,9 +79,6 @@ def anchor_head_get_bboxes(rewriter,
|
|||
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
|
||||
scores = scores[batch_inds, topk_inds, :]
|
||||
|
||||
if not is_dynamic_shape(deploy_cfg):
|
||||
img_shape = [int(val) for val in img_shape]
|
||||
|
||||
bboxes = self.bbox_coder.decode(
|
||||
anchors, bbox_pred, max_shape=img_shape)
|
||||
mlvl_bboxes.append(bboxes)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, is_dynamic_shape, mark
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
|
@ -15,8 +15,12 @@ def single_stage_extract_feat(rewriter, self, img):
|
|||
@FUNCTION_REWRITERS.register_rewriter(
|
||||
func_name='mmdet.models.SingleStageDetector.forward')
|
||||
def single_stage_forward(rewriter, self, data, **kwargs):
|
||||
deploy_cfg = rewriter.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
# get origin input shape to support onnx dynamic shape
|
||||
img_shape = torch._shape_as_tensor(data)[2:]
|
||||
if not is_dynamic_flag:
|
||||
img_shape = [int(val) for val in img_shape]
|
||||
x = self.extract_feat(data)
|
||||
outs = self.bbox_head(x)
|
||||
return self.bbox_head.get_bboxes(*outs, img_shape, **kwargs)
|
||||
|
|
|
@ -41,6 +41,7 @@ def set_symbolic(cfg: Dict,
|
|||
func,
|
||||
Function), '{} is not an torch.autograd.Function'.format(
|
||||
func_name)
|
||||
symbolic_impl.origin_func = getattr(func, 'symbolic', None)
|
||||
func.symbolic = symbolic_impl
|
||||
except Exception:
|
||||
logging.warning(
|
||||
|
|
Loading…
Reference in New Issue