mirror of https://github.com/open-mmlab/mmyolo.git
[Deploy] MMYOLO model convert to onnx for deployment. (#279)
* Format code
* Support ONNXRUNTIME
* Support mmyolo model convert to onnx for deploy
* Same as dev branch
* Support yolox focus rewrite
* Support GConv Focus
* Update mmyolo/easydeploy/backbone/focus.py
* Add TensorRT build/infer Wrapper
* Add image_demo for deploy model
* Fix
* Merge dev
* Remove image-demo
* Roll back to dev
* Support model switch to deploy
* Remove --deploy
* Add new deploy method
* Format code and add doc
* Move md to project
* Add readme and readme_zh to easy_deploy
* Update projects/easydeploy/README_zh-CN.md
* Update projects/easydeploy/README.md

Co-authored-by: HinGwenWoong <peterhuang0323@qq.com>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
Co-authored-by: xin-li-67 <williamlee.xin@gmail.com>
Co-authored-by: Xin Li <7219519+xin-li-67@users.noreply.github.com>
parent b1e478a8bd
commit 1045b41b68
@ -0,0 +1,11 @@
# MMYOLO Model Easy-Deployment

## Introduction

This project is developed for easily converting your MMYOLO models to other inference backends without the need of MMDeploy, which reduces the cost of both time and effort on getting familiar with MMDeploy.

Currently we support converting to `ONNX` and `TensorRT` formats, and other inference backends such as `ncnn` will be added to this project as well.

## Supported Backends

- [Model Convert](docs/model_convert.md)
@ -0,0 +1,11 @@
# MMYOLO Model Conversion

## Introduction

This project exists as a standalone deployment project for MMYOLO. It is intended to be decoupled from the current MMDeploy ecosystem and to support model conversion and deployment on its own after training, lowering the learning and engineering cost for users.

Currently it supports converting to the ONNX and TensorRT formats; other inference backends will be supported later on.

## Conversion Tutorial

- [Model Convert](docs/model_convert.md)
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .focus import DeployFocus, GConvFocus, NcnnFocus

__all__ = ['DeployFocus', 'NcnnFocus', 'GConvFocus']
@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class DeployFocus(nn.Module):
    """Focus rewritten with reshape/permute only, which exports cleanly to
    ONNXRuntime and TensorRT."""

    def __init__(self, orin_Focus: nn.Module):
        super().__init__()
        self.__dict__.update(orin_Focus.__dict__)

    def forward(self, x: Tensor) -> Tensor:
        # (B, C, H, W) -> (B, 4C, H/2, W/2) space-to-depth without slicing
        batch_size, channel, height, width = x.shape
        x = x.reshape(batch_size, channel, -1, 2, width)
        x = x.reshape(batch_size, channel, x.shape[2], 2, -1, 2)
        half_h = x.shape[2]
        half_w = x.shape[4]
        x = x.permute(0, 5, 3, 1, 2, 4)
        x = x.reshape(batch_size, channel * 4, half_h, half_w)

        return self.conv(x)


class NcnnFocus(nn.Module):
    """Focus rewritten with reshape/transpose patterns that ncnn can fuse
    into its ShuffleChannel layer."""

    def __init__(self, orin_Focus: nn.Module):
        super().__init__()
        self.__dict__.update(orin_Focus.__dict__)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, c, h, w = x.shape
        assert h % 2 == 0 and w % 2 == 0, \
            f'focus for yolox needs even feature height and width, got {(h, w)}.'

        x = x.reshape(batch_size, c * h, 1, w)
        _b, _c, _h, _w = x.shape
        g = _c // 2
        # fuse to ncnn's shufflechannel
        x = x.view(_b, g, 2, _h, _w)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(_b, -1, _h, _w)

        x = x.reshape(_b, c * h * w, 1, 1)

        _b, _c, _h, _w = x.shape
        g = _c // 2
        # fuse to ncnn's shufflechannel
        x = x.view(_b, g, 2, _h, _w)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(_b, -1, _h, _w)

        x = x.reshape(_b, c * 4, h // 2, w // 2)

        return self.conv(x)


class GConvFocus(nn.Module):
    """Focus implemented as four stride-2 grouped convolutions, one per
    spatial offset of the 2x2 space-to-depth pattern."""

    def __init__(self, orin_Focus: nn.Module):
        super().__init__()
        device = next(orin_Focus.parameters()).device
        self.weight1 = torch.tensor([[1., 0], [0, 0]]).expand(3, 1, 2,
                                                              2).to(device)
        self.weight2 = torch.tensor([[0, 0], [1., 0]]).expand(3, 1, 2,
                                                              2).to(device)
        self.weight3 = torch.tensor([[0, 1.], [0, 0]]).expand(3, 1, 2,
                                                              2).to(device)
        self.weight4 = torch.tensor([[0, 0], [0, 1.]]).expand(3, 1, 2,
                                                              2).to(device)
        self.__dict__.update(orin_Focus.__dict__)

    def forward(self, x: Tensor) -> Tensor:
        conv1 = F.conv2d(x, self.weight1, stride=2, groups=3)
        conv2 = F.conv2d(x, self.weight2, stride=2, groups=3)
        conv3 = F.conv2d(x, self.weight3, stride=2, groups=3)
        conv4 = F.conv2d(x, self.weight4, stride=2, groups=3)
        return self.conv(torch.cat([conv1, conv2, conv3, conv4], dim=1))
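As a quick sanity check that the reshape/permute rewrite reproduces the slice-based Focus channel order, here is a minimal sketch. The `SliceFocus` module below is a hypothetical stand-in that only mimics mmdet's slicing pattern (top-left, bottom-left, top-right, bottom-right); it is not part of this project:

```python
import torch
import torch.nn as nn

from projects.easydeploy.backbone import DeployFocus


class SliceFocus(nn.Module):
    """Hypothetical stand-in mimicking mmdet's slice-based Focus."""

    def __init__(self, in_c: int = 3, out_c: int = 16):
        super().__init__()
        self.conv = nn.Conv2d(in_c * 4, out_c, 1)

    def forward(self, x):
        # 2x2 space-to-depth via strided slicing, in mmdet's concat order
        patches = torch.cat([
            x[..., ::2, ::2], x[..., 1::2, ::2],
            x[..., ::2, 1::2], x[..., 1::2, 1::2]
        ], dim=1)
        return self.conv(patches)


focus = SliceFocus().eval()
rewrite = DeployFocus(focus).eval()  # shares focus's conv via __dict__
x = torch.randn(2, 3, 64, 64)
with torch.no_grad():
    assert torch.allclose(focus(x), rewrite(x), atol=1e-6)
```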
@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_coder import rtmdet_bbox_decoder, yolov5_bbox_decoder

__all__ = ['yolov5_bbox_decoder', 'rtmdet_bbox_decoder']
@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from torch import Tensor


def yolov5_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
                        stride: Tensor) -> Tensor:
    """Decode YOLOv5 regression outputs into (cx, cy, w, h) boxes."""
    bbox_preds = bbox_preds.sigmoid()

    x_center = (priors[..., 0] + priors[..., 2]) * 0.5
    y_center = (priors[..., 1] + priors[..., 3]) * 0.5
    w = priors[..., 2] - priors[..., 0]
    h = priors[..., 3] - priors[..., 1]

    # center offset lies in (-stride, stride), size scale in (0, 4)
    x_center_pred = (bbox_preds[..., 0] - 0.5) * 2 * stride + x_center
    y_center_pred = (bbox_preds[..., 1] - 0.5) * 2 * stride + y_center
    w_pred = (bbox_preds[..., 2] * 2)**2 * w
    h_pred = (bbox_preds[..., 3] * 2)**2 * h

    decoded_bboxes = torch.stack(
        [x_center_pred, y_center_pred, w_pred, h_pred], dim=-1)

    return decoded_bboxes


def rtmdet_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
                        stride: Optional[Tensor]) -> Tensor:
    """Decode RTMDet distance predictions into (x1, y1, x2, y2) boxes."""
    tl_x = (priors[..., 0] - bbox_preds[..., 0])
    tl_y = (priors[..., 1] - bbox_preds[..., 1])
    br_x = (priors[..., 0] + bbox_preds[..., 2])
    br_y = (priors[..., 1] + bbox_preds[..., 3])
    decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
    return decoded_bboxes
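A small worked example of the YOLOv5 decode above (values chosen by hand, not from the source): a zero raw prediction decodes back to the prior itself, since sigmoid(0) = 0.5 gives a zero center offset and a (0.5 * 2)^2 = 1 scale on width and height:

```python
import torch

from projects.easydeploy.bbox_code import yolov5_bbox_decoder

# One (x1, y1, x2, y2) prior on a stride-8 level.
priors = torch.tensor([[0., 0., 8., 8.]])
stride = torch.tensor([8.])
bbox_preds = torch.zeros(1, 4)  # raw logits, sigmoid(0) = 0.5

print(yolov5_bbox_decoder(priors, bbox_preds, stride))
# tensor([[4., 4., 8., 8.]]) -> the (cx, cy, w, h) of the prior itself
```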
@ -0,0 +1,56 @@
# MMYOLO Model ONNX Conversion

## Environment Dependencies

- [onnx](https://github.com/onnx/onnx)

```shell
pip install onnx
```

- [onnx-simplifier](https://github.com/daquexian/onnx-simplifier) (optional, used to simplify the model)

```shell
pip install onnx-simplifier
```

## Usage

The [model export script](./projects/easydeploy/tools/export.py) converts an `MMYOLO` model to `onnx`.

### Arguments

- `config` : The config file used to build the model, e.g. [`yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py`](./configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py).
- `checkpoint` : The trained checkpoint file, e.g. `yolov5s.pth`.
- `--work-dir` : The directory to save the converted model.
- `--img-size` : The input size used when converting the model, e.g. `640 640`.
- `--batch-size` : The input `batch size` of the converted model.
- `--device` : The device used for conversion, defaults to `cuda:0`.
- `--simplify` : Whether to simplify the exported `onnx` model; requires [onnx-simplifier](https://github.com/daquexian/onnx-simplifier). Disabled by default.
- `--opset` : The `onnx` opset version to export with, defaults to `11`.
- `--backend` : The id of the backend the exported `onnx` targets: `ONNXRuntime`: `1`, `TensorRT8`: `2`, `TensorRT7`: `3`. Defaults to `1`, i.e. `ONNXRuntime`.
- `--pre-topk` : The number of candidate boxes kept before NMS in the exported post-processing, defaults to `1000`.
- `--keep-topk` : The number of candidate boxes output by NMS, defaults to `100`.
- `--iou-threshold` : The `iou` threshold used by NMS to drop duplicate candidates, defaults to `0.65`.
- `--score-threshold` : The score threshold used by NMS to filter candidates, defaults to `0.25`.

Example:

```shell
python ./projects/easydeploy/tools/export.py \
	configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
	yolov5s.pth \
	--work-dir work_dir \
	--img-size 640 640 \
	--batch-size 1 \
	--device cpu \
	--simplify \
	--opset 11 \
	--backend 1 \
	--pre-topk 1000 \
	--keep-topk 100 \
	--iou-threshold 0.65 \
	--score-threshold 0.25
```

The resulting `onnx` file can then be read by backend tools such as `TensorRT` and converted again into the backend's own model format, e.g. `.engine`/`.plan`.
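As a sketch of that last step using the `EngineBuilder` class shipped in this project (assuming the `work_dir/end2end.onnx` produced above and an available CUDA device; TensorRT must be installed):

```python
from projects.easydeploy.model import EngineBuilder

# Build an FP16 engine from the exported static-shape ONNX file;
# the result is written next to the .onnx as end2end.engine.
builder = EngineBuilder(
    'work_dir/end2end.onnx', opt_shape=(1, 3, 640, 640), device='cuda:0')
builder.build(fp16=True)
```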
@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backendwrapper import BackendWrapper, EngineBuilder
from .model import DeployModel

__all__ = ['DeployModel', 'BackendWrapper', 'EngineBuilder']
@ -0,0 +1,256 @@
import warnings
from collections import OrderedDict, namedtuple
from functools import partial
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import onnxruntime
import tensorrt as trt
import torch
from numpy import ndarray
from torch import Tensor

warnings.filterwarnings(action='ignore', category=DeprecationWarning)


class BackendWrapper:
    """Unified inference wrapper for ONNXRuntime (.onnx) and TensorRT
    (.engine/.plan) models."""

    def __init__(
            self,
            weight: Union[str, Path],
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        weight = Path(weight) if isinstance(weight, str) else weight
        assert weight.exists() and weight.suffix in ('.onnx', '.engine',
                                                     '.plan')
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')
        self.weight = weight
        self.device = device
        self.__build_model()
        self.__init_runtime()
        self.__warm_up(10)

    def __build_model(self) -> None:
        """Load the model and record input/output names, shapes and dtypes."""
        model_info = dict()
        num_input = num_output = 0
        names = []
        is_dynamic = False
        if self.weight.suffix == '.onnx':
            model_info['backend'] = 'ONNXRuntime'
            providers = ['CPUExecutionProvider']
            if 'cuda' in self.device.type:
                providers.insert(0, 'CUDAExecutionProvider')
            model = onnxruntime.InferenceSession(
                str(self.weight), providers=providers)
            for tensor in model.get_inputs():
                model_info[tensor.name] = dict(
                    shape=tensor.shape, dtype=tensor.type)
                num_input += 1
                names.append(tensor.name)
                # string entries in the shape mean dynamic axes
                is_dynamic |= any(
                    map(lambda x: isinstance(x, str), tensor.shape))
            for tensor in model.get_outputs():
                model_info[tensor.name] = dict(
                    shape=tensor.shape, dtype=tensor.type)
                num_output += 1
                names.append(tensor.name)
        else:
            model_info['backend'] = 'TensorRT'
            logger = trt.Logger(trt.Logger.ERROR)
            trt.init_libnvinfer_plugins(logger, namespace='')
            with trt.Runtime(logger) as runtime:
                model = runtime.deserialize_cuda_engine(
                    self.weight.read_bytes())
            profile_shape = []
            for i in range(model.num_bindings):
                name = model.get_binding_name(i)
                shape = tuple(model.get_binding_shape(i))
                dtype = trt.nptype(model.get_binding_dtype(i))
                is_dynamic |= (-1 in shape)
                if model.binding_is_input(i):
                    num_input += 1
                    profile_shape.append(model.get_profile_shape(i, 0))
                else:
                    num_output += 1
                model_info[name] = dict(shape=shape, dtype=dtype)
                names.append(name)
            model_info['profile_shape'] = profile_shape

        self.num_input = num_input
        self.num_output = num_output
        self.names = names
        self.is_dynamic = is_dynamic
        self.model = model
        self.model_info = model_info

    def __init_runtime(self) -> None:
        """Create the execution context and pre-allocate static bindings."""
        bindings = OrderedDict()
        Binding = namedtuple('Binding',
                             ('name', 'dtype', 'shape', 'data', 'ptr'))
        if self.model_info['backend'] == 'TensorRT':
            context = self.model.create_execution_context()
            for name in self.names:
                shape, dtype = self.model_info[name].values()
                if self.is_dynamic:
                    # shapes are only known at infer time
                    cpu_tensor, gpu_tensor, ptr = None, None, None
                else:
                    cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
                    gpu_tensor = torch.from_numpy(cpu_tensor).to(self.device)
                    ptr = int(gpu_tensor.data_ptr())
                bindings[name] = Binding(name, dtype, shape, gpu_tensor, ptr)
        else:
            output_names = []
            for i, name in enumerate(self.names):
                if i >= self.num_input:
                    output_names.append(name)
                shape, dtype = self.model_info[name].values()
                bindings[name] = Binding(name, dtype, shape, None, None)
            context = partial(self.model.run, output_names)
        self.addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
        self.bindings = bindings
        self.context = context

    def __infer(
            self, inputs: List[Union[ndarray,
                                     Tensor]]) -> List[Union[ndarray, Tensor]]:
        assert len(inputs) == self.num_input
        if self.model_info['backend'] == 'TensorRT':
            outputs = []
            for i, (name, gpu_input) in enumerate(
                    zip(self.names[:self.num_input], inputs)):
                if self.is_dynamic:
                    self.context.set_binding_shape(i, gpu_input.shape)
                self.addrs[name] = gpu_input.data_ptr()

            for i, name in enumerate(self.names[self.num_input:]):
                i += self.num_input
                if self.is_dynamic:
                    # allocate output buffers for the resolved shapes
                    shape = tuple(self.context.get_binding_shape(i))
                    dtype = self.bindings[name].dtype
                    cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
                    out = torch.from_numpy(cpu_tensor).to(self.device)
                    self.addrs[name] = out.data_ptr()
                else:
                    out = self.bindings[name].data
                outputs.append(out)
            assert self.context.execute_v2(list(
                self.addrs.values())), 'Inference failed'
        else:
            input_feed = {
                name: inputs[i]
                for i, name in enumerate(self.names[:self.num_input])
            }
            outputs = self.context(input_feed)
        return outputs

    def __warm_up(self, n=10) -> None:
        """Run `n` dummy forward passes to warm up a TensorRT context."""
        for _ in range(n):
            _tmp = []
            if self.model_info['backend'] == 'TensorRT':
                for i, name in enumerate(self.names[:self.num_input]):
                    if self.is_dynamic:
                        # use the opt shape of the first profile
                        shape = self.model_info['profile_shape'][i][1]
                        dtype = self.bindings[name].dtype
                        cpu_tensor = np.empty(shape, dtype=np.dtype(dtype))
                        _tmp.append(
                            torch.from_numpy(cpu_tensor).to(self.device))
                    else:
                        _tmp.append(self.bindings[name].data)
            else:
                print('ONNXRuntime models are not warmed up here; '
                      'please warm up the model yourself if needed')
                return
            _ = self.__infer(_tmp)

    def __call__(
            self, inputs: Union[List, Tensor,
                                ndarray]) -> List[Union[Tensor, ndarray]]:
        if not isinstance(inputs, list):
            inputs = [inputs]
        outputs = self.__infer(inputs)
        return outputs


class EngineBuilder:
    """Build a TensorRT engine from an ONNX file."""

    def __init__(
            self,
            checkpoint: Union[str, Path],
            opt_shape: Union[Tuple, List] = (1, 3, 640, 640),
            device: Optional[Union[str, int, torch.device]] = None) -> None:
        checkpoint = Path(checkpoint) if isinstance(checkpoint,
                                                    str) else checkpoint
        assert checkpoint.exists() and checkpoint.suffix == '.onnx'
        if isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device(f'cuda:{device}')

        self.checkpoint = checkpoint
        self.opt_shape = np.array(opt_shape, dtype=np.float32)
        self.device = device

    def __build_engine(self,
                       scale: Optional[List[List]] = None,
                       fp16: bool = True,
                       with_profiling: bool = True) -> None:
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        builder = trt.Builder(logger)
        config = builder.create_builder_config()
        config.max_workspace_size = torch.cuda.get_device_properties(
            self.device).total_memory
        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        network = builder.create_network(flag)
        parser = trt.OnnxParser(network, logger)
        if not parser.parse_from_file(str(self.checkpoint)):
            raise RuntimeError(
                f'failed to load ONNX file: {str(self.checkpoint)}')
        inputs = [network.get_input(i) for i in range(network.num_inputs)]
        outputs = [network.get_output(i) for i in range(network.num_outputs)]
        profile = None
        dshape = -1 in network.get_input(0).shape
        if dshape:
            profile = builder.create_optimization_profile()
            if scale is None:
                # default (min, opt, max) multipliers on (n, c, h, w)
                scale = np.array(
                    [[1, 1, 0.5, 0.5], [1, 1, 1, 1], [4, 1, 1.5, 1.5]],
                    dtype=np.float32)
                scale = (self.opt_shape * scale).astype(np.int32)
            elif isinstance(scale, List):
                scale = np.array(scale, dtype=np.int32)
                assert scale.shape[0] == 3, 'Input a wrong scale list'
            else:
                raise NotImplementedError

        for inp in inputs:
            logger.log(
                trt.Logger.WARNING,
                f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
            if dshape:
                profile.set_shape(inp.name, *scale)
        for out in outputs:
            logger.log(
                trt.Logger.WARNING,
                f'output "{out.name}" with shape{out.shape} {out.dtype}')
        if fp16 and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        self.weight = self.checkpoint.with_suffix('.engine')
        if dshape:
            config.add_optimization_profile(profile)
        if with_profiling:
            config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
        with builder.build_engine(network, config) as engine:
            self.weight.write_bytes(engine.serialize())
        logger.log(
            trt.Logger.WARNING, 'Build tensorrt engine finished.\n'
            f'Saved in {str(self.weight.absolute())}')

    def build(self,
              scale: Optional[List[List]] = None,
              fp16: bool = True,
              with_profiling=True):
        self.__build_engine(scale, fp16, with_profiling)
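A minimal usage sketch for `BackendWrapper`, assuming an engine exported by this project (output names `num_det`, `det_boxes`, `det_scores`, `det_classes`) and an already pre-processed input; the file path is illustrative:

```python
import torch

from projects.easydeploy.model import BackendWrapper

model = BackendWrapper('work_dir/end2end.engine', device='cuda:0')
# TensorRT engines expect device-resident tensors as inputs.
image = torch.randn(1, 3, 640, 640, device='cuda:0')
num_det, det_boxes, det_scores, det_classes = model(image)
```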
@ -0,0 +1,144 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from typing import List, Optional

import torch
import torch.nn as nn
from mmdet.models.backbones.csp_darknet import Focus
from mmengine.config import ConfigDict
from torch import Tensor

from mmyolo.models import RepVGGBlock
from mmyolo.models.dense_heads import RTMDetHead, YOLOv5Head
from ..backbone import DeployFocus, GConvFocus, NcnnFocus
from ..bbox_code import rtmdet_bbox_decoder, yolov5_bbox_decoder
from ..nms import batched_nms, efficient_nms, onnx_nms


class DeployModel(nn.Module):
    """Wrap a detector so that its forward pass outputs post-processed,
    NMS-ed detections suitable for ONNX export."""

    def __init__(self,
                 baseModel: nn.Module,
                 postprocess_cfg: Optional[ConfigDict] = None):
        super().__init__()
        self.baseModel = baseModel
        self.baseHead = baseModel.bbox_head
        self.__init_sub_attributes()
        detector_type = type(self.baseHead)
        if postprocess_cfg is None:
            pre_top_k = 1000
            keep_top_k = 100
            iou_threshold = 0.65
            score_threshold = 0.25
            backend = 1
        else:
            pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
            keep_top_k = postprocess_cfg.get('keep_top_k', 100)
            iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
            score_threshold = postprocess_cfg.get('score_threshold', 0.25)
            backend = postprocess_cfg.get('backend', 1)
        # expose the locals above (pre_top_k, backend, ...) as attributes;
        # this must run before __switch_deploy, which reads self.backend
        self.__dict__.update(locals())
        self.__switch_deploy()

    def __init_sub_attributes(self):
        self.bbox_decoder = self.baseHead.bbox_coder.decode
        self.prior_generate = self.baseHead.prior_generator.grid_priors
        self.num_base_priors = self.baseHead.num_base_priors
        self.featmap_strides = self.baseHead.featmap_strides
        self.num_classes = self.baseHead.num_classes

    def __switch_deploy(self):
        for layer in self.baseModel.modules():
            if isinstance(layer, RepVGGBlock):
                layer.switch_to_deploy()
            if isinstance(layer, Focus):
                # onnxruntime tensorrt8 tensorrt7
                if self.backend in (1, 2, 3):
                    self.baseModel.backbone.stem = DeployFocus(layer)
                # ncnn
                elif self.backend == 4:
                    self.baseModel.backbone.stem = NcnnFocus(layer)
                # switch focus to group conv
                else:
                    self.baseModel.backbone.stem = GConvFocus(layer)

    def pred_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     objectnesses: Optional[List[Tensor]] = None,
                     **kwargs):
        assert len(cls_scores) == len(bbox_preds)
        dtype = cls_scores[0].dtype
        device = cls_scores[0].device

        nms_func = self.select_nms()
        if self.detector_type is YOLOv5Head:
            bbox_decoder = yolov5_bbox_decoder
        elif self.detector_type is RTMDetHead:
            bbox_decoder = rtmdet_bbox_decoder
        else:
            bbox_decoder = self.bbox_decoder

        num_imgs = cls_scores[0].shape[0]
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        mlvl_priors = self.prior_generate(
            featmap_sizes, dtype=dtype, device=device)

        flatten_priors = torch.cat(mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
                stride) for featmap_size, stride in zip(
                    featmap_sizes, self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()

        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)

        if objectnesses is not None:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
            cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))

        scores = cls_scores

        bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
                              flatten_stride)

        return nms_func(bboxes, scores, self.keep_top_k, self.iou_threshold,
                        self.score_threshold, self.pre_top_k, self.keep_top_k)

    def select_nms(self):
        """Pick the NMS implementation matching the target backend id."""
        if self.backend == 1:
            nms_func = onnx_nms
        elif self.backend == 2:
            nms_func = efficient_nms
        elif self.backend == 3:
            nms_func = batched_nms
        else:
            raise NotImplementedError
        # YOLOv5 decodes to (cx, cy, w, h), so tell NMS to convert the boxes
        if type(self.baseHead) is YOLOv5Head:
            nms_func = partial(nms_func, box_coding=1)
        return nms_func

    def forward(self, inputs: Tensor):
        neck_outputs = self.baseModel(inputs)
        outputs = self.pred_by_feat(*neck_outputs)
        return outputs
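Mirroring what `tools/export.py` below does, wrapping a trained detector looks like this (a sketch; the config and checkpoint paths are the ones used in the docs and must exist locally):

```python
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict

from mmyolo.utils import register_all_modules
from projects.easydeploy.model import DeployModel

register_all_modules()
detector = init_detector(
    'configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py',
    'yolov5s.pth',
    device='cpu')
deploy_model = DeployModel(
    baseModel=detector,
    postprocess_cfg=ConfigDict(
        pre_top_k=1000,
        keep_top_k=100,
        iou_threshold=0.65,
        score_threshold=0.25,
        backend=1)).eval()

with torch.no_grad():
    num_det, det_boxes, det_scores, det_classes = deploy_model(
        torch.randn(1, 3, 640, 640))
```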
@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ort_nms import onnx_nms
from .trt_nms import batched_nms, efficient_nms

__all__ = ['efficient_nms', 'batched_nms', 'onnx_nms']
@ -0,0 +1,122 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor

# Right-multiplying (cx, cy, w, h) boxes by this matrix converts them to
# (x1, y1, x2, y2).
_XYWH2XYXY = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0],
                           [-0.5, 0.0, 0.5, 0.0], [0.0, -0.5, 0.0, 0.5]],
                          dtype=torch.float32)


def select_nms_index(scores: Tensor,
                     boxes: Tensor,
                     nms_index: Tensor,
                     batch_size: int,
                     keep_top_k: int = -1):
    """Gather NMS-selected boxes back into padded per-image tensors."""
    batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1]
    box_inds = nms_index[:, 2]

    scores = scores[batch_inds, cls_inds, box_inds].unsqueeze(1)
    boxes = boxes[batch_inds, box_inds, ...]
    dets = torch.cat([boxes, scores], dim=1)

    batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1)
    batch_template = torch.arange(
        0, batch_size, dtype=batch_inds.dtype, device=batch_inds.device)
    batched_dets = batched_dets.where(
        (batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
        batched_dets.new_zeros(1))

    batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1)
    batched_labels = batched_labels.where(
        (batch_inds == batch_template.unsqueeze(1)),
        batched_labels.new_ones(1) * -1)

    N = batched_dets.shape[0]

    # pad one zero row/label so sorting always has at least one entry
    batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))),
                             1)
    batched_labels = torch.cat((batched_labels, -batched_labels.new_ones(
        (N, 1))), 1)

    _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
    topk_batch_inds = torch.arange(
        batch_size, dtype=topk_inds.dtype,
        device=topk_inds.device).view(-1, 1)
    batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
    batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
    batched_dets, batched_scores = batched_dets.split([4, 1], 2)
    batched_scores = batched_scores.squeeze(-1)

    num_dets = (batched_scores > 0).sum(1, keepdim=True)
    return num_dets, batched_dets, batched_scores, batched_labels


class ONNXNMSop(torch.autograd.Function):
    """Export-time stand-in for the ONNX `NonMaxSuppression` op."""

    @staticmethod
    def forward(
            ctx,
            boxes: Tensor,
            scores: Tensor,
            max_output_boxes_per_class: Tensor = torch.tensor([100]),
            iou_threshold: Tensor = torch.tensor([0.5]),
            score_threshold: Tensor = torch.tensor([0.05])
    ) -> Tensor:
        # Random placeholder indices: only shape/dtype matter while tracing;
        # the real NonMaxSuppression runs in the backend.
        device = boxes.device
        batch = scores.shape[0]
        num_det = 20
        batches = torch.randint(0, batch, (num_det, )).sort()[0].to(device)
        idxs = torch.arange(100, 100 + num_det).to(device)
        zeros = torch.zeros((num_det, ), dtype=torch.int64).to(device)
        selected_indices = torch.cat([batches[None], zeros[None], idxs[None]],
                                     0).T.contiguous()
        selected_indices = selected_indices.to(torch.int64)

        return selected_indices

    @staticmethod
    def symbolic(
            g,
            boxes: Tensor,
            scores: Tensor,
            max_output_boxes_per_class: Tensor = torch.tensor([100]),
            iou_threshold: Tensor = torch.tensor([0.5]),
            score_threshold: Tensor = torch.tensor([0.05]),
    ):
        return g.op(
            'NonMaxSuppression',
            boxes,
            scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            outputs=1)


def onnx_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    max_output_boxes_per_class = torch.tensor([max_output_boxes_per_class])
    iou_threshold = torch.tensor([iou_threshold])
    score_threshold = torch.tensor([score_threshold])

    batch_size, _, _ = scores.shape
    if box_coding == 1:
        boxes = boxes @ (_XYWH2XYXY.to(boxes.device))
    scores = scores.transpose(1, 2).contiguous()
    selected_indices = ONNXNMSop.apply(boxes, scores,
                                       max_output_boxes_per_class,
                                       iou_threshold, score_threshold)

    num_dets, batched_dets, batched_scores, batched_labels = select_nms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)

    return num_dets, batched_dets, batched_scores, batched_labels.to(
        torch.int32)
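To see what the `_XYWH2XYXY` matrix does, a hand-computed check: right-multiplying a `(cx, cy, w, h)` row by it yields `(cx - w/2, cy - h/2, cx + w/2, cy + h/2)`:

```python
import torch

from projects.easydeploy.nms.ort_nms import _XYWH2XYXY

box_cxcywh = torch.tensor([[4., 4., 8., 8.]])
print(box_cxcywh @ _XYWH2XYXY)
# tensor([[0., 0., 8., 8.]]) -> the same box as (x1, y1, x2, y2)
```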
@ -0,0 +1,220 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor


class TRTEfficientNMSop(torch.autograd.Function):
    """Export-time stand-in for the TensorRT `EfficientNMS_TRT` plugin."""

    @staticmethod
    def forward(
        ctx,
        boxes: Tensor,
        scores: Tensor,
        background_class: int = -1,
        box_coding: int = 0,
        iou_threshold: float = 0.45,
        max_output_boxes: int = 100,
        plugin_version: str = '1',
        score_activation: int = 0,
        score_threshold: float = 0.25,
    ):
        # Random placeholders: only shapes/dtypes matter during tracing;
        # the real NMS runs inside the TensorRT plugin.
        batch_size, _, num_classes = scores.shape
        num_det = torch.randint(
            0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4)
        det_scores = torch.randn(batch_size, max_output_boxes)
        det_classes = torch.randint(
            0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes: Tensor,
                 scores: Tensor,
                 background_class: int = -1,
                 box_coding: int = 0,
                 iou_threshold: float = 0.45,
                 max_output_boxes: int = 100,
                 plugin_version: str = '1',
                 score_activation: int = 0,
                 score_threshold: float = 0.25):
        out = g.op(
            'TRT::EfficientNMS_TRT',
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4)
        num_det, det_boxes, det_scores, det_classes = out
        return num_det, det_boxes, det_scores, det_classes


class TRTbatchedNMSop(torch.autograd.Function):
    """Export-time stand-in for the TensorRT `BatchedNMSDynamic_TRT` plugin."""

    @staticmethod
    def forward(
        ctx,
        boxes: Tensor,
        scores: Tensor,
        plugin_version: str = '1',
        shareLocation: int = 1,
        backgroundLabelId: int = -1,
        numClasses: int = 80,
        topK: int = 1000,
        keepTopK: int = 100,
        scoreThreshold: float = 0.25,
        iouThreshold: float = 0.45,
        isNormalized: int = 0,
        clipBoxes: int = 0,
        scoreBits: int = 16,
        caffeSemantics: int = 1,
    ):
        # Random placeholders with the plugin's output shapes and dtypes.
        batch_size, _, numClasses = scores.shape
        num_det = torch.randint(
            0, keepTopK, (batch_size, 1), dtype=torch.int32)
        det_boxes = torch.randn(batch_size, keepTopK, 4)
        det_scores = torch.randn(batch_size, keepTopK)
        det_classes = torch.randint(0, numClasses,
                                    (batch_size, keepTopK)).float()
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(
        g,
        boxes: Tensor,
        scores: Tensor,
        plugin_version: str = '1',
        shareLocation: int = 1,
        backgroundLabelId: int = -1,
        numClasses: int = 80,
        topK: int = 1000,
        keepTopK: int = 100,
        scoreThreshold: float = 0.25,
        iouThreshold: float = 0.45,
        isNormalized: int = 0,
        clipBoxes: int = 0,
        scoreBits: int = 16,
        caffeSemantics: int = 1,
    ):
        out = g.op(
            'TRT::BatchedNMSDynamic_TRT',
            boxes,
            scores,
            shareLocation_i=shareLocation,
            plugin_version_s=plugin_version,
            backgroundLabelId_i=backgroundLabelId,
            numClasses_i=numClasses,
            topK_i=topK,
            keepTopK_i=keepTopK,
            scoreThreshold_f=scoreThreshold,
            iouThreshold_f=iouThreshold,
            isNormalized_i=isNormalized,
            clipBoxes_i=clipBoxes,
            scoreBits_i=scoreBits,
            caffeSemantics_i=caffeSemantics,
            outputs=4)
        num_det, det_boxes, det_scores, det_classes = out
        return num_det, det_boxes, det_scores, det_classes


def _efficient_nms(
    boxes: Tensor,
    scores: Tensor,
    max_output_boxes_per_class: int = 1000,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Wrapper for the TensorRT `EfficientNMS_TRT` plugin.

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): Score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to 100.
        box_coding (int): Bounding boxes format for nms.
            Defaults to 0 means [x1, y1, x2, y2].
            Set to 1 means [x, y, w, h].

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            (num_det, det_boxes, det_scores, det_classes),
            `num_det` of shape [N, 1],
            `det_boxes` of shape [N, num_det, 4],
            `det_scores` of shape [N, num_det],
            `det_classes` of shape [N, num_det].
    """
    num_det, det_boxes, det_scores, det_classes = TRTEfficientNMSop.apply(
        boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
        score_threshold)
    return num_det, det_boxes, det_scores, det_classes


def _batched_nms(
    boxes: Tensor,
    scores: Tensor,
    max_output_boxes_per_class: int = 1000,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Wrapper for the TensorRT `BatchedNMSDynamic_TRT` plugin.

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): Score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to 100.
        box_coding (int): Bounding boxes format for nms.
            Defaults to 0 means [x1, y1, x2, y2].
            Set to 1 means [x, y, w, h].

    Returns:
        tuple[Tensor, Tensor, Tensor, Tensor]:
            (num_det, det_boxes, det_scores, det_classes),
            `num_det` of shape [N, 1],
            `det_boxes` of shape [N, num_det, 4],
            `det_scores` of shape [N, num_det],
            `det_classes` of shape [N, num_det].
    """
    # the plugin expects (N, num_boxes, 1, 4) when shareLocation is enabled
    boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
    _, _, numClasses = scores.shape

    num_det, det_boxes, det_scores, det_classes = TRTbatchedNMSop.apply(
        boxes, scores, '1', 1, -1, int(numClasses), min(pre_top_k, 4096),
        keep_top_k, score_threshold, iou_threshold, 0, 0, 16, 1)

    det_classes = det_classes.int()
    return num_det, det_boxes, det_scores, det_classes


def efficient_nms(*args, **kwargs):
    """Wrapper function for `_efficient_nms`."""
    return _efficient_nms(*args, **kwargs)


def batched_nms(*args, **kwargs):
    """Wrapper function for `_batched_nms`."""
    return _batched_nms(*args, **kwargs)
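Note that the `forward` passes of both TRT ops return random placeholder tensors; they only exist so tracing records the correct output shapes and dtypes, while the real NMS runs inside the TensorRT plugin. A quick eager shape check (the values are meaningless):

```python
import torch

from projects.easydeploy.nms import efficient_nms

boxes = torch.randn(1, 8400, 4)
scores = torch.rand(1, 8400, 80)
num_det, det_boxes, det_scores, det_classes = efficient_nms(
    boxes, scores, keep_top_k=100)
print(num_det.shape, det_boxes.shape, det_scores.shape, det_classes.shape)
# torch.Size([1, 1]) torch.Size([1, 100, 4]) torch.Size([1, 100]) torch.Size([1, 100])
```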
@ -0,0 +1,43 @@
import argparse
import ast

from ..model import EngineBuilder


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='TensorRT builder device')
    parser.add_argument(
        '--scales',
        type=str,
        default='[[1,3,640,640],[1,3,640,640],[1,3,640,640]]',
        help='Input scales for build dynamic input shape engine')
    parser.add_argument(
        '--fp16', action='store_true', help='Build model with fp16 mode')
    args = parser.parse_args()
    args.img_size *= 2 if len(args.img_size) == 1 else 1
    return args


def main(args):
    img_size = (1, 3, *args.img_size)
    try:
        # literal_eval instead of a bare eval: --scales must be a plain
        # Python literal such as a nested list
        scales = ast.literal_eval(args.scales)
    except Exception:
        print('Input scales is not a valid Python literal, '
              'falling back to scales=None')
        scales = None
    builder = EngineBuilder(args.checkpoint, img_size, args.device)
    builder.build(scales, fp16=args.fp16)


if __name__ == '__main__':
    args = parse_args()
    main(args)
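The `--scales` argument maps to `EngineBuilder.build(scale=...)`: three `(n, c, h, w)` rows giving the min/opt/max shapes of one TensorRT optimization profile. A sketch, assuming an ONNX file that was exported with a dynamic batch axis (the exporter in this project emits static shapes by default, in which case the profile is skipped):

```python
from projects.easydeploy.model import EngineBuilder

builder = EngineBuilder(
    'work_dir/end2end.onnx', opt_shape=(1, 3, 640, 640), device='cuda:0')
# min / opt / max shapes for the dynamic input
builder.build(
    scale=[[1, 3, 640, 640], [1, 3, 640, 640], [4, 3, 640, 640]], fp16=True)
```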
@ -0,0 +1,135 @@
import argparse
import os
import warnings
from io import BytesIO

import onnx
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict

from mmyolo.utils import register_all_modules
from projects.easydeploy.model import DeployModel

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Simplify onnx model by onnx-sim')
    parser.add_argument(
        '--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument(
        '--backend', type=int, default=1, help='Backend for export onnx')
    parser.add_argument(
        '--pre-topk',
        type=int,
        default=1000,
        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument(
        '--keep-topk',
        type=int,
        default=100,
        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument(
        '--iou-threshold',
        type=float,
        default=0.65,
        help='IoU threshold for NMS')
    parser.add_argument(
        '--score-threshold',
        type=float,
        default=0.25,
        help='Score threshold for NMS')
    args = parser.parse_args()
    # a single value is used for both height and width
    args.img_size *= 2 if len(args.img_size) == 1 else 1
    return args


def build_model_from_cfg(config_path, checkpoint_path, device):
    model = init_detector(config_path, checkpoint_path, device=device)
    model.eval()
    return model


def main():
    args = parse_args()
    register_all_modules()

    if not os.path.exists(args.work_dir):
        os.mkdir(args.work_dir)

    postprocess_cfg = ConfigDict(
        pre_top_k=args.pre_topk,
        keep_top_k=args.keep_topk,
        iou_threshold=args.iou_threshold,
        score_threshold=args.score_threshold,
        backend=args.backend)

    baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)

    deploy_model = DeployModel(
        baseModel=baseModel, postprocess_cfg=postprocess_cfg)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # dry run
    deploy_model(fake_input)

    save_onnx_path = os.path.join(args.work_dir, 'end2end.onnx')
    # export onnx
    with BytesIO() as f:
        torch.onnx.export(
            deploy_model,
            fake_input,
            f,
            input_names=['images'],
            output_names=['num_det', 'det_boxes', 'det_scores', 'det_classes'],
            opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
        onnx.checker.check_model(onnx_model)

    # Fix tensorrt onnx output shape, just for view
    if args.backend in (2, 3):
        shapes = [
            args.batch_size, 1, args.batch_size, args.keep_topk, 4,
            args.batch_size, args.keep_topk, args.batch_size,
            args.keep_topk
        ]
        for i in onnx_model.graph.output:
            for j in i.type.tensor_type.shape.dim:
                j.dim_param = str(shapes.pop(0))
    if args.simplify:
        try:
            import onnxsim
            onnx_model, check = onnxsim.simplify(onnx_model)
            assert check, 'assert check failed'
        except Exception as e:
            print(f'Simplify failure: {e}')
    onnx.save(onnx_model, save_onnx_path)
    print(f'ONNX export success, saved into {save_onnx_path}')


if __name__ == '__main__':
    main()
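After export, the model can be spot-checked directly with ONNXRuntime (a sketch; the input name `images` and the four output names come from the `torch.onnx.export` call above):

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    'work_dir/end2end.onnx', providers=['CPUExecutionProvider'])
image = np.random.randn(1, 3, 640, 640).astype(np.float32)
num_det, det_boxes, det_scores, det_classes = session.run(
    None, {'images': image})
print(num_det.shape, det_boxes.shape)
```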