[Deploy] MMYOLO model convert to onnx for deployment. (#279)

* Fromat code

* Support ONNXRUNTIME

* Support mmyolo model convert to onnx for deploy.

* Same as dev branch

* Support yolox focus rewrite

* Support GConv Focus

* Update mmyolo/easydeploy/backbone/focus.py

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>

* Add TensorRT build/infer Wrapper

* Add image_demo for deploy model

* Fix

* Merge dev

* Remove image-demo

* Roll back to dev

* Support model switch to deploy

* Remove --deploy

* Add new deploy method

* Format code and add doc

* Move md to project

* add readme and readme_zh to easy_deploy

* Update projects/easydeploy/README_zh-CN.md

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>

* Update projects/easydeploy/README.md

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: xin-li-67 <williamlee.xin@gmail.com>
Co-authored-by: Xin Li <7219519+xin-li-67@users.noreply.github.com>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
pull/315/head
tripleMu 2022-11-23 10:43:41 +08:00 committed by GitHub
parent b1e478a8bd
commit 1045b41b68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1130 additions and 0 deletions

View File

@ -0,0 +1,11 @@
# MMYOLO Model Easy-Deployment
## Introduction
This project is developed for easily converting your MMYOLO models to other inference backends without the need of MMDeploy, which reduces the cost of both time and effort on getting familiar with MMDeploy.
Currently we support converting to `ONNX` and `TensorRT` formats, other inference backends such `ncnn` will be added to this project as well.
## Supported Backends
- [Model Convert](docs/model_convert.md)

View File

@ -0,0 +1,11 @@
# MMYOLO 模型转换
## 介绍
本项目作为 MMYOLO 的部署 project 单独存在,意图剥离 MMDeploy 当前的体系,独自支持用户完成模型训练后的转换和部署功能,使用户的学习和工程成本下降。
当前支持对 ONNX 格式和 TensorRT 格式的转换,后续对其他推理平台也会支持起来。
## 转换教程
- [Model Convert](docs/model_convert.md)

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .focus import DeployFocus, GConvFocus, NcnnFocus
__all__ = ['DeployFocus', 'NcnnFocus', 'GConvFocus']

View File

@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class DeployFocus(nn.Module):
def __init__(self, orin_Focus: nn.Module):
super().__init__()
self.__dict__.update(orin_Focus.__dict__)
def forward(self, x: Tensor) -> Tensor:
batch_size, channel, height, width = x.shape
x = x.reshape(batch_size, channel, -1, 2, width)
x = x.reshape(batch_size, channel, x.shape[2], 2, -1, 2)
half_h = x.shape[2]
half_w = x.shape[4]
x = x.permute(0, 5, 3, 1, 2, 4)
x = x.reshape(batch_size, channel * 4, half_h, half_w)
return self.conv(x)
class NcnnFocus(nn.Module):
def __init__(self, orin_Focus: nn.Module):
super().__init__()
self.__dict__.update(orin_Focus.__dict__)
def forward(self, x: Tensor) -> Tensor:
batch_size, c, h, w = x.shape
assert h % 2 == 0 and w % 2 == 0, f'focus for yolox needs even feature\
height and width, got {(h, w)}.'
x = x.reshape(batch_size, c * h, 1, w)
_b, _c, _h, _w = x.shape
g = _c // 2
# fuse to ncnn's shufflechannel
x = x.view(_b, g, 2, _h, _w)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(_b, -1, _h, _w)
x = x.reshape(_b, c * h * w, 1, 1)
_b, _c, _h, _w = x.shape
g = _c // 2
# fuse to ncnn's shufflechannel
x = x.view(_b, g, 2, _h, _w)
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(_b, -1, _h, _w)
x = x.reshape(_b, c * 4, h // 2, w // 2)
return self.conv(x)
class GConvFocus(nn.Module):
def __init__(self, orin_Focus: nn.Module):
super().__init__()
device = next(orin_Focus.parameters()).device
self.weight1 = torch.tensor([[1., 0], [0, 0]]).expand(3, 1, 2,
2).to(device)
self.weight2 = torch.tensor([[0, 0], [1., 0]]).expand(3, 1, 2,
2).to(device)
self.weight3 = torch.tensor([[0, 1.], [0, 0]]).expand(3, 1, 2,
2).to(device)
self.weight4 = torch.tensor([[0, 0], [0, 1.]]).expand(3, 1, 2,
2).to(device)
self.__dict__.update(orin_Focus.__dict__)
def forward(self, x: Tensor) -> Tensor:
conv1 = F.conv2d(x, self.weight1, stride=2, groups=3)
conv2 = F.conv2d(x, self.weight2, stride=2, groups=3)
conv3 = F.conv2d(x, self.weight3, stride=2, groups=3)
conv4 = F.conv2d(x, self.weight4, stride=2, groups=3)
return self.conv(torch.cat([conv1, conv2, conv3, conv4], dim=1))

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_coder import rtmdet_bbox_decoder, yolov5_bbox_decoder
__all__ = ['yolov5_bbox_decoder', 'rtmdet_bbox_decoder']

View File

@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from torch import Tensor
def yolov5_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
stride: Tensor) -> Tensor:
bbox_preds = bbox_preds.sigmoid()
x_center = (priors[..., 0] + priors[..., 2]) * 0.5
y_center = (priors[..., 1] + priors[..., 3]) * 0.5
w = priors[..., 2] - priors[..., 0]
h = priors[..., 3] - priors[..., 1]
x_center_pred = (bbox_preds[..., 0] - 0.5) * 2 * stride + x_center
y_center_pred = (bbox_preds[..., 1] - 0.5) * 2 * stride + y_center
w_pred = (bbox_preds[..., 2] * 2)**2 * w
h_pred = (bbox_preds[..., 3] * 2)**2 * h
decoded_bboxes = torch.stack(
[x_center_pred, y_center_pred, w_pred, h_pred], dim=-1)
return decoded_bboxes
def rtmdet_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
stride: Optional[Tensor]) -> Tensor:
tl_x = (priors[..., 0] - bbox_preds[..., 0])
tl_y = (priors[..., 1] - bbox_preds[..., 1])
br_x = (priors[..., 0] + bbox_preds[..., 2])
br_y = (priors[..., 1] + bbox_preds[..., 3])
decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
return decoded_bboxes

View File

@ -0,0 +1,56 @@
# MMYOLO 模型 ONNX 转换
## 环境依赖
- [onnx](https://github.com/onnx/onnx)
```shell
pip install onnx
```
[onnx-simplifier](https://github.com/daquexian/onnx-simplifier) (可选,用于简化模型)
```shell
pip install onnx-simplifier
```
## 使用方法
[模型导出脚本](./projects/easydeploy/tools/export.py)用于将 `MMYOLO` 模型转换为 `onnx`
### 参数介绍:
- `config` : 构建模型使用的配置文件,如 [`yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py`](./configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py) 。
- `checkpoint` : 训练得到的权重文件,如 `yolov5s.pth`
- `--work-dir` : 转换后的模型保存路径。
- `--img-size`: 转换模型时输入的尺寸,如 `640 640`
- `--batch-size`: 转换后的模型输入 `batch size`
- `--device`: 转换模型使用的设备,默认为 `cuda:0`
- `--simplify`: 是否简化导出的 `onnx` 模型,需要安装 [onnx-simplifier](https://github.com/daquexian/onnx-simplifier),默认关闭。
- `--opset`: 指定导出 `onnx``opset`,默认为 `11`
- `--backend`: 指定导出 `onnx` 用于的后端 id`ONNXRuntime`: `1`, `TensorRT8`: `2`, `TensorRT7`: `3`,默认为`1`即 `ONNXRuntime`
- `--pre-topk`: 指定导出 `onnx` 的后处理筛选候选框个数阈值,默认为 `1000`
- `--keep-topk`: 指定导出 `onnx` 的非极大值抑制输出的候选框个数阈值,默认为 `100`
- `--iou-threshold`: 非极大值抑制中过滤重复候选框的 `iou` 阈值,默认为 `0.65`
- `--score-threshold`: 非极大值抑制中过滤候选框得分的阈值,默认为 `0.25`
例子:
```shell
python ./projects/easydeploy/tools/export.py \
configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5s.pth \
--work-dir work_dir \
--img-size 640 640 \
--batch 1 \
--device cpu \
--simplify \
--opset 11 \
--backend 1 \
--pre-topk 1000 \
--keep-topk 100 \
--iou-threshold 0.65 \
--score-threshold 0.25
```
然后利用后端支持的工具如 `TensorRT` 读取 `onnx` 再次转换为后端支持的模型格式如 `.engine/.plan`

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backendwrapper import BackendWrapper, EngineBuilder
from .model import DeployModel
__all__ = ['DeployModel', 'BackendWrapper', 'EngineBuilder']

View File

@ -0,0 +1,256 @@
import warnings
from collections import OrderedDict, namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import onnxruntime
import tensorrt as trt
import torch
from numpy import ndarray
from torch import Tensor
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
class BackendWrapper:
def __init__(
self,
weight: Union[str, Path],
device: Optional[Union[str, int, torch.device]] = None) -> None:
weight = Path(weight) if isinstance(weight, str) else weight
assert weight.exists() and weight.suffix in ('.onnx', '.engine',
'.plan')
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device(f'cuda:{device}')
self.weight = weight
self.device = device
self.__build_model()
self.__init_runtime()
self.__warm_up(10)
def __build_model(self) -> None:
model_info = dict()
num_input = num_output = 0
names = []
is_dynamic = False
if self.weight.suffix == '.onnx':
model_info['backend'] = 'ONNXRuntime'
providers = ['CPUExecutionProvider']
if 'cuda' in self.device.type:
providers.insert(0, 'CUDAExecutionProvider')
model = onnxruntime.InferenceSession(
str(self.weight), providers=providers)
for i, tensor in enumerate(model.get_inputs()):
model_info[tensor.name] = dict(
shape=tensor.shape, dtype=tensor.type)
num_input += 1
names.append(tensor.name)
is_dynamic |= any(
map(lambda x: isinstance(x, str), tensor.shape))
for i, tensor in enumerate(model.get_outputs()):
model_info[tensor.name] = dict(
shape=tensor.shape, dtype=tensor.type)
num_output += 1
names.append(tensor.name)
else:
model_info['backend'] = 'TensorRT'
logger = trt.Logger(trt.Logger.ERROR)
trt.init_libnvinfer_plugins(logger, namespace='')
with trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(
self.weight.read_bytes())
profile_shape = []
for i in range(model.num_bindings):
name = model.get_binding_name(i)
shape = tuple(model.get_binding_shape(i))
dtype = trt.nptype(model.get_binding_dtype(i))
is_dynamic |= (-1 in shape)
if model.binding_is_input(i):
num_input += 1
profile_shape.append(model.get_profile_shape(i, 0))
else:
num_output += 1
model_info[name] = dict(shape=shape, dtype=dtype)
names.append(name)
model_info['profile_shape'] = profile_shape
self.num_input = num_input
self.num_output = num_output
self.names = names
self.is_dynamic = is_dynamic
self.model = model
self.model_info = model_info
def __init_runtime(self) -> None:
bindings = OrderedDict()
Binding = namedtuple('Binding',
('name', 'dtype', 'shape', 'data', 'ptr'))
if self.model_info['backend'] == 'TensorRT':
context = self.model.create_execution_context()
for name in self.names:
shape, dtype = self.model_info[name].values()
if self.is_dynamic:
cpu_tensor, gpu_tensor, ptr = None, None, None
else:
cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
gpu_tensor = torch.from_numpy(cpu_tensor).to(self.device)
ptr = int(gpu_tensor.data_ptr())
bindings[name] = Binding(name, dtype, shape, gpu_tensor, ptr)
else:
output_names = []
for i, name in enumerate(self.names):
if i >= self.num_input:
output_names.append(name)
shape, dtype = self.model_info[name].values()
bindings[name] = Binding(name, dtype, shape, None, None)
context = partial(self.model.run, output_names)
self.addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
self.bindings = bindings
self.context = context
def __infer(
self, inputs: List[Union[ndarray,
Tensor]]) -> List[Union[ndarray, Tensor]]:
assert len(inputs) == self.num_input
if self.model_info['backend'] == 'TensorRT':
outputs = []
for i, (name, gpu_input) in enumerate(
zip(self.names[:self.num_input], inputs)):
if self.is_dynamic:
self.context.set_binding_shape(i, gpu_input.shape)
self.addrs[name] = gpu_input.data_ptr()
for i, name in enumerate(self.names[self.num_input:]):
i += self.num_input
if self.is_dynamic:
shape = tuple(self.context.get_binding_shape(i))
dtype = self.bindings[name].dtype
cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
out = torch.from_numpy(cpu_tensor).to(self.device)
self.addrs[name] = out.data_ptr()
else:
out = self.bindings[name].data
outputs.append(out)
assert self.context.execute_v2(list(
self.addrs.values())), 'Infer fault'
else:
input_feed = {
name: inputs[i]
for i, name in enumerate(self.names[:self.num_input])
}
outputs = self.context(input_feed)
return outputs
def __warm_up(self, n=10) -> None:
for _ in range(n):
_tmp = []
if self.model_info['backend'] == 'TensorRT':
for i, name in enumerate(self.names[:self.num_input]):
if self.is_dynamic:
shape = self.model_info['profile_shape'][i][1]
dtype = self.bindings[name].dtype
cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
_tmp.append(
torch.from_numpy(cpu_tensor).to(self.device))
else:
_tmp.append(self.bindings[name].data)
else:
print('Please warm up ONNXRuntime model by yourself')
print("So this model doesn't warm up")
return
_ = self.__infer(_tmp)
def __call__(
self, inputs: Union[List, Tensor,
ndarray]) -> List[Union[Tensor, ndarray]]:
if not isinstance(inputs, list):
inputs = [inputs]
outputs = self.__infer(inputs)
return outputs
class EngineBuilder:
def __init__(
self,
checkpoint: Union[str, Path],
opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
device: Optional[Union[str, int, torch.device]] = None) -> None:
checkpoint = Path(checkpoint) if isinstance(checkpoint,
str) else checkpoint
assert checkpoint.exists() and checkpoint.suffix == '.onnx'
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device(f'cuda:{device}')
self.checkpoint = checkpoint
self.opt_shape = np.array(opt_shape, dtype=np.float32)
self.device = device
def __build_engine(self,
scale: Optional[List[List]] = None,
fp16: bool = True,
with_profiling: bool = True) -> None:
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, namespace='')
builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = torch.cuda.get_device_properties(
self.device).total_memory
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(self.checkpoint)):
raise RuntimeError(
f'failed to load ONNX file: {str(self.checkpoint)}')
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
profile = None
dshape = -1 in network.get_input(0).shape
if dshape:
profile = builder.create_optimization_profile()
if scale is None:
scale = np.array(
[[1, 1, 0.5, 0.5], [1, 1, 1, 1], [4, 1, 1.5, 1.5]],
dtype=np.float32)
scale = (self.opt_shape * scale).astype(np.int32)
elif isinstance(scale, List):
scale = np.array(scale, dtype=np.int32)
assert scale.shape[0] == 3, 'Input a wrong scale list'
else:
raise NotImplementedError
for inp in inputs:
logger.log(
trt.Logger.WARNING,
f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
if dshape:
profile.set_shape(inp.name, *scale)
for out in outputs:
logger.log(
trt.Logger.WARNING,
f'output "{out.name}" with shape{out.shape} {out.dtype}')
if fp16 and builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
self.weight = self.checkpoint.with_suffix('.engine')
if dshape:
config.add_optimization_profile(profile)
if with_profiling:
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
with builder.build_engine(network, config) as engine:
self.weight.write_bytes(engine.serialize())
logger.log(
trt.Logger.WARNING, f'Build tensorrt engine finish.\n'
f'Save in {str(self.weight.absolute())}')
def build(self,
scale: Optional[List[List]] = None,
fp16: bool = True,
with_profiling=True):
self.__build_engine(scale, fp16, with_profiling)

View File

@ -0,0 +1,144 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import List, Optional
import torch
import torch.nn as nn
from mmdet.models.backbones.csp_darknet import Focus
from mmengine.config import ConfigDict
from torch import Tensor
from mmyolo.models import RepVGGBlock
from mmyolo.models.dense_heads import RTMDetHead, YOLOv5Head
from ..backbone import DeployFocus, GConvFocus, NcnnFocus
from ..bbox_code import rtmdet_bbox_decoder, yolov5_bbox_decoder
from ..nms import batched_nms, efficient_nms, onnx_nms
class DeployModel(nn.Module):
def __init__(self,
baseModel: nn.Module,
postprocess_cfg: Optional[ConfigDict] = None):
super().__init__()
self.baseModel = baseModel
self.baseHead = baseModel.bbox_head
self.__init_sub_attributes()
detector_type = type(self.baseHead)
if postprocess_cfg is None:
pre_top_k = 1000
keep_top_k = 100
iou_threshold = 0.65
score_threshold = 0.25
backend = 1
else:
pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
keep_top_k = postprocess_cfg.get('keep_top_k', 100)
iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
score_threshold = postprocess_cfg.get('score_threshold', 0.25)
backend = postprocess_cfg.get('backend', 1)
self.__switch_deploy()
self.__dict__.update(locals())
def __init_sub_attributes(self):
self.bbox_decoder = self.baseHead.bbox_coder.decode
self.prior_generate = self.baseHead.prior_generator.grid_priors
self.num_base_priors = self.baseHead.num_base_priors
self.featmap_strides = self.baseHead.featmap_strides
self.num_classes = self.baseHead.num_classes
def __switch_deploy(self):
for layer in self.baseModel.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
if isinstance(layer, Focus):
# onnxruntime tensorrt8 tensorrt7
if self.backend in (1, 2, 3):
self.baseModel.backbone.stem = DeployFocus(layer)
# ncnn
elif self.backend == 4:
self.baseModel.backbone.stem = NcnnFocus(layer)
# switch focus to group conv
else:
self.baseModel.backbone.stem = GConvFocus(layer)
def pred_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
**kwargs):
assert len(cls_scores) == len(bbox_preds)
dtype = cls_scores[0].dtype
device = cls_scores[0].device
nms_func = self.select_nms()
if self.detector_type is YOLOv5Head:
bbox_decoder = yolov5_bbox_decoder
elif self.detector_type is RTMDetHead:
bbox_decoder = rtmdet_bbox_decoder
else:
bbox_decoder = self.bbox_decoder
num_imgs = cls_scores[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
mlvl_priors = self.prior_generate(
featmap_sizes, dtype=dtype, device=device)
flatten_priors = torch.cat(mlvl_priors)
mlvl_strides = [
flatten_priors.new_full(
(featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
stride) for featmap_size, stride in zip(
featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_classes)
for cls_score in cls_scores
]
cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
if objectnesses is not None:
flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
scores = cls_scores
bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
flatten_stride)
return nms_func(bboxes, scores, self.keep_top_k, self.iou_threshold,
self.score_threshold, self.pre_top_k, self.keep_top_k)
def select_nms(self):
if self.backend == 1:
nms_func = onnx_nms
elif self.backend == 2:
nms_func = efficient_nms
elif self.backend == 3:
nms_func = batched_nms
else:
raise NotImplementedError
if type(self.baseHead) is YOLOv5Head:
nms_func = partial(nms_func, box_coding=1)
return nms_func
def forward(self, inputs: Tensor):
neck_outputs = self.baseModel(inputs)
outputs = self.pred_by_feat(*neck_outputs)
return outputs

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ort_nms import onnx_nms
from .trt_nms import batched_nms, efficient_nms
__all__ = ['efficient_nms', 'batched_nms', 'onnx_nms']

View File

@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
_XYWH2XYXY = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0],
[-0.5, 0.0, 0.5, 0.0], [0.0, -0.5, 0.0, 0.5]],
dtype=torch.float32)
def select_nms_index(scores: Tensor,
boxes: Tensor,
nms_index: Tensor,
batch_size: int,
keep_top_k: int = -1):
batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1]
box_inds = nms_index[:, 2]
scores = scores[batch_inds, cls_inds, box_inds].unsqueeze(1)
boxes = boxes[batch_inds, box_inds, ...]
dets = torch.cat([boxes, scores], dim=1)
batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1)
batch_template = torch.arange(
0, batch_size, dtype=batch_inds.dtype, device=batch_inds.device)
batched_dets = batched_dets.where(
(batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
batched_dets.new_zeros(1))
batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1)
batched_labels = batched_labels.where(
(batch_inds == batch_template.unsqueeze(1)),
batched_labels.new_ones(1) * -1)
N = batched_dets.shape[0]
batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))),
1)
batched_labels = torch.cat((batched_labels, -batched_labels.new_ones(
(N, 1))), 1)
_, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
topk_batch_inds = torch.arange(
batch_size, dtype=topk_inds.dtype,
device=topk_inds.device).view(-1, 1)
batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
batched_dets, batched_scores = batched_dets.split([4, 1], 2)
batched_scores = batched_scores.squeeze(-1)
num_dets = (batched_scores > 0).sum(1, keepdim=True)
return num_dets, batched_dets, batched_scores, batched_labels
class ONNXNMSop(torch.autograd.Function):
@staticmethod
def forward(
ctx,
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: Tensor = torch.tensor([100]),
iou_threshold: Tensor = torch.tensor([0.5]),
score_threshold: Tensor = torch.tensor([0.05])
) -> Tensor:
device = boxes.device
batch = scores.shape[0]
num_det = 20
batches = torch.randint(0, batch, (num_det, )).sort()[0].to(device)
idxs = torch.arange(100, 100 + num_det).to(device)
zeros = torch.zeros((num_det, ), dtype=torch.int64).to(device)
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]],
0).T.contiguous()
selected_indices = selected_indices.to(torch.int64)
return selected_indices
@staticmethod
def symbolic(
g,
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: Tensor = torch.tensor([100]),
iou_threshold: Tensor = torch.tensor([0.5]),
score_threshold: Tensor = torch.tensor([0.05]),
):
return g.op(
'NonMaxSuppression',
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
outputs=1)
def onnx_nms(
boxes: torch.Tensor,
scores: torch.Tensor,
max_output_boxes_per_class: int = 100,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = 100,
box_coding: int = 0,
):
max_output_boxes_per_class = torch.tensor([max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold])
score_threshold = torch.tensor([score_threshold])
batch_size, _, _ = scores.shape
if box_coding == 1:
boxes = boxes @ (_XYWH2XYXY.to(boxes.device))
scores = scores.transpose(1, 2).contiguous()
selected_indices = ONNXNMSop.apply(boxes, scores,
max_output_boxes_per_class,
iou_threshold, score_threshold)
num_dets, batched_dets, batched_scores, batched_labels = select_nms_index(
scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
return num_dets, batched_dets, batched_scores, batched_labels.to(
torch.int32)

View File

@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
class TRTEfficientNMSop(torch.autograd.Function):
@staticmethod
def forward(
ctx,
boxes: Tensor,
scores: Tensor,
background_class: int = -1,
box_coding: int = 0,
iou_threshold: float = 0.45,
max_output_boxes: int = 100,
plugin_version: str = '1',
score_activation: int = 0,
score_threshold: float = 0.25,
):
batch_size, _, num_classes = scores.shape
num_det = torch.randint(
0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
det_scores = torch.randn(batch_size, max_output_boxes)
det_classes = torch.randint(
0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
return num_det, det_boxes, det_scores, det_classes
@staticmethod
def symbolic(g,
boxes: Tensor,
scores: Tensor,
background_class: int = -1,
box_coding: int = 0,
iou_threshold: float = 0.45,
max_output_boxes: int = 100,
plugin_version: str = '1',
score_activation: int = 0,
score_threshold: float = 0.25):
out = g.op(
'TRT::EfficientNMS_TRT',
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
num_det, det_boxes, det_scores, det_classes = out
return num_det, det_boxes, det_scores, det_classes
class TRTbatchedNMSop(torch.autograd.Function):
"""TensorRT NMS operation."""
@staticmethod
def forward(
ctx,
boxes: Tensor,
scores: Tensor,
plugin_version: str = '1',
shareLocation: int = 1,
backgroundLabelId: int = -1,
numClasses: int = 80,
topK: int = 1000,
keepTopK: int = 100,
scoreThreshold: float = 0.25,
iouThreshold: float = 0.45,
isNormalized: int = 0,
clipBoxes: int = 0,
scoreBits: int = 16,
caffeSemantics: int = 1,
):
batch_size, _, numClasses = scores.shape
num_det = torch.randint(
0, keepTopK, (batch_size, 1), dtype=torch.int32)
det_boxes = torch.randn(batch_size, keepTopK, 4)
det_scores = torch.randn(batch_size, keepTopK)
det_classes = torch.randint(0, numClasses,
(batch_size, keepTopK)).float()
return num_det, det_boxes, det_scores, det_classes
@staticmethod
def symbolic(
g,
boxes: Tensor,
scores: Tensor,
plugin_version: str = '1',
shareLocation: int = 1,
backgroundLabelId: int = -1,
numClasses: int = 80,
topK: int = 1000,
keepTopK: int = 100,
scoreThreshold: float = 0.25,
iouThreshold: float = 0.45,
isNormalized: int = 0,
clipBoxes: int = 0,
scoreBits: int = 16,
caffeSemantics: int = 1,
):
out = g.op(
'TRT::BatchedNMSDynamic_TRT',
boxes,
scores,
shareLocation_i=shareLocation,
plugin_version_s=plugin_version,
backgroundLabelId_i=backgroundLabelId,
numClasses_i=numClasses,
topK_i=topK,
keepTopK_i=keepTopK,
scoreThreshold_f=scoreThreshold,
iouThreshold_f=iouThreshold,
isNormalized_i=isNormalized,
clipBoxes_i=clipBoxes,
scoreBits_i=scoreBits,
caffeSemantics_i=caffeSemantics,
outputs=4)
num_det, det_boxes, det_scores, det_classes = out
return num_det, det_boxes, det_scores, det_classes
def _efficient_nms(
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = 100,
box_coding: int = 0,
):
"""Wrapper for `efficient_nms` with TensorRT.
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
box_coding (int): Bounding boxes format for nms.
Defaults to 0 means [x1, y1 ,x2, y2].
Set to 1 means [x, y, w, h].
Returns:
tuple[Tensor, Tensor, Tensor, Tensor]:
(num_det, det_boxes, det_scores, det_classes),
`num_det` of shape [N, 1]
`det_boxes` of shape [N, num_det, 4]
`det_scores` of shape [N, num_det]
`det_classes` of shape [N, num_det]
"""
num_det, det_boxes, det_scores, det_classes = TRTEfficientNMSop.apply(
boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
score_threshold)
return num_det, det_boxes, det_scores, det_classes
def _batched_nms(
boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = 100,
box_coding: int = 0,
):
"""Wrapper for `efficient_nms` with TensorRT.
Args:
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
scores (Tensor): The detection scores of shape
[N, num_boxes, num_classes].
max_output_boxes_per_class (int): Maximum number of output
boxes per class of nms. Defaults to 1000.
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
score_threshold (float): score threshold of nms.
Defaults to 0.05.
pre_top_k (int): Number of top K boxes to keep before nms.
Defaults to -1.
keep_top_k (int): Number of top K boxes to keep after nms.
Defaults to -1.
box_coding (int): Bounding boxes format for nms.
Defaults to 0 means [x1, y1 ,x2, y2].
Set to 1 means [x, y, w, h].
Returns:
tuple[Tensor, Tensor, Tensor, Tensor]:
(num_det, det_boxes, det_scores, det_classes),
`num_det` of shape [N, 1]
`det_boxes` of shape [N, num_det, 4]
`det_scores` of shape [N, num_det]
`det_classes` of shape [N, num_det]
"""
boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
_, _, numClasses = scores.shape
num_det, det_boxes, det_scores, det_classes = TRTbatchedNMSop.apply(
boxes, scores, '1', 1, -1, int(numClasses), min(pre_top_k, 4096),
keep_top_k, score_threshold, iou_threshold, 0, 0, 16, 1)
det_classes = det_classes.int()
return num_det, det_boxes, det_scores, det_classes
def efficient_nms(*args, **kwargs):
"""Wrapper function for `_efficient_nms`."""
return _efficient_nms(*args, **kwargs)
def batched_nms(*args, **kwargs):
"""Wrapper function for `_batched_nms`."""
return _batched_nms(*args, **kwargs)

View File

@ -0,0 +1,43 @@
import argparse
from ..model import EngineBuilder
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--img-size',
nargs='+',
type=int,
default=[640, 640],
help='Image size of height and width')
parser.add_argument(
'--device', type=str, default='cuda:0', help='TensorRT builder device')
parser.add_argument(
'--scales',
type=str,
default='[[1,3,640,640],[1,3,640,640],[1,3,640,640]]',
help='Input scales for build dynamic input shape engine')
parser.add_argument(
'--fp16', action='store_true', help='Build model with fp16 mode')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1
return args
def main(args):
img_size = (1, 3, *args.img_size)
try:
scales = eval(args.scales)
except Exception:
print('Input scales is not a python variable')
print('Set scales default None')
scales = None
builder = EngineBuilder(args.checkpoint, img_size, args.device)
builder.build(scales, fp16=args.fp16)
if __name__ == '__main__':
args = parse_args()
main(args)

View File

@ -0,0 +1,135 @@
import argparse
import os
import warnings
from io import BytesIO
import onnx
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict
from mmyolo.utils import register_all_modules
from projects.easydeploy.model import DeployModel
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--work-dir', default='./work_dir', help='Path to save export model')
parser.add_argument(
'--img-size',
nargs='+',
type=int,
default=[640, 640],
help='Image size of height and width')
parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--simplify',
action='store_true',
help='Simplify onnx model by onnx-sim')
parser.add_argument(
'--opset', type=int, default=11, help='ONNX opset version')
parser.add_argument(
'--backend', type=int, default=1, help='Backend for export onnx')
parser.add_argument(
'--pre-topk',
type=int,
default=1000,
help='Postprocess pre topk bboxes feed into NMS')
parser.add_argument(
'--keep-topk',
type=int,
default=100,
help='Postprocess keep topk bboxes out of NMS')
parser.add_argument(
'--iou-threshold',
type=float,
default=0.65,
help='IoU threshold for NMS')
parser.add_argument(
'--score-threshold',
type=float,
default=0.25,
help='Score threshold for NMS')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1
return args
def build_model_from_cfg(config_path, checkpoint_path, device):
model = init_detector(config_path, checkpoint_path, device=device)
model.eval()
return model
def main():
args = parse_args()
register_all_modules()
if not os.path.exists(args.work_dir):
os.mkdir(args.work_dir)
postprocess_cfg = ConfigDict(
pre_top_k=args.pre_topk,
keep_top_k=args.keep_topk,
iou_threshold=args.iou_threshold,
score_threshold=args.score_threshold,
backend=args.backend)
baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
deploy_model = DeployModel(
baseModel=baseModel, postprocess_cfg=postprocess_cfg)
deploy_model.eval()
fake_input = torch.randn(args.batch_size, 3,
*args.img_size).to(args.device)
# dry run
deploy_model(fake_input)
save_onnx_path = os.path.join(args.work_dir, 'end2end.onnx')
# export onnx
with BytesIO() as f:
torch.onnx.export(
deploy_model,
fake_input,
f,
input_names=['images'],
output_names=['num_det', 'det_boxes', 'det_scores', 'det_classes'],
opset_version=args.opset)
f.seek(0)
onnx_model = onnx.load(f)
onnx.checker.check_model(onnx_model)
# Fix tensorrt onnx output shape, just for view
if args.backend in (2, 3):
shapes = [
args.batch_size, 1, args.batch_size, args.keep_topk, 4,
args.batch_size, args.keep_topk, args.batch_size,
args.keep_topk
]
for i in onnx_model.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))
if args.simplify:
try:
import onnxsim
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
print(f'Simplify failure: {e}')
onnx.save(onnx_model, save_onnx_path)
print(f'ONNX export success, save into {save_onnx_path}')
if __name__ == '__main__':
main()