Support model-only inference (#733)

* Support model-only inference

* Fix ppyoloe std

* Add doc

* Fix typo

* Rename
pull/649/merge
triple Mu 2023-04-20 13:55:34 +08:00 committed by GitHub
parent 927e0a46af
commit 1aa1ecd27b
7 changed files with 680 additions and 2 deletions

View File

@ -1,5 +1,7 @@
# MMYOLO Model ONNX Conversion
## 1. Export ONNX supported by the backend
## Requirements
- [onnx](https://github.com/onnx/onnx)
@ -14,6 +16,8 @@
pip install onnx-simplifier
```
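***Please make sure you run the scripts from the `MMYOLO` root directory; otherwise the required dependencies may not be found.***
## Usage
The [model export script](./projects/easydeploy/tools/export_onnx.py) converts an `MMYOLO` model to `onnx`.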
@ -28,7 +32,7 @@
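- `--device`: Device used to convert the model. Defaults to `cuda:0`.
- `--simplify`: Whether to simplify the exported `onnx` model; requires [onnx-simplifier](https://github.com/daquexian/onnx-simplifier). Disabled by default.
- `--opset`: The `opset` of the exported `onnx`. Defaults to `11`.
- `--backend`: Backend id that the exported `onnx` targets. `ONNXRuntime`: `onnxruntime`, `TensorRT8`: `tensorrt8`, `TensorRT7`: `tensorrt7`. Defaults to `onnxruntime`, i.e. `ONNXRuntime`.
- `--backend`: Backend name that the exported `onnx` targets. `ONNXRuntime`: `onnxruntime`, `TensorRT8`: `tensorrt8`, `TensorRT7`: `tensorrt7`. Defaults to `onnxruntime`, i.e. `ONNXRuntime`.
- `--pre-topk`: Threshold on the number of candidate boxes kept by the post-processing of the exported `onnx`. Defaults to `1000`.
- `--keep-topk`: Threshold on the number of boxes output by non-maximum suppression in the exported `onnx`. Defaults to `100`.
- `--iou-threshold`: `iou` threshold used by non-maximum suppression to filter duplicate candidate boxes. Defaults to `0.65`.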
@ -54,4 +58,99 @@ python ./projects/easydeploy/tools/export.py \
--score-threshold 0.25
```
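Then use a tool supported by the backend, such as `TensorRT`, to read the `onnx` and convert it again into a model format the backend supports, such as `.engine/.plan`
Then use a tool supported by the backend, such as `TensorRT`, to read the `onnx` and convert it again into a model format the backend supports, such as `.engine/.plan`, and so on.
`MMYOLO` currently supports end-to-end model conversion for the `TensorRT8`, `TensorRT7`, and `ONNXRuntime` backends. Only static-shape models can be exported and converted for now; end-to-end conversion of models with a dynamic batch size or dynamic height and width will be supported in the future.
The inputs and outputs of an `onnx` model produced by end-to-end conversion are shown below: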
<div align=center>
<img src="https://user-images.githubusercontent.com/92794867/232403745-101ca999-2003-46fa-bc5b-6b0eb2b2d41b.png"/>
</div>
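Input name: `images`, size 640x640
Output name: `num_dets`, size 1x1, the number of detected objects.
Output name: `boxes`, size 1x100x4, the coordinates of the detection boxes in `x1y1x2y2` format.
Output name: `scores`, size 1x100, the scores of the detection boxes.
Output name: `labels`, size 1x100, the class ids of the detection boxes.
You can use the count in `num_dets` to truncate `boxes`, `scores`, and `labels`, taking the first `num_dets` of the 100 detections as the final result.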
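The truncation described above can be sketched with NumPy. This is only a minimal illustration, assuming the four outputs have already been fetched as arrays with the shapes listed above; the helper name `truncate_detections` and the dummy data are hypothetical and not part of `MMYOLO`.

```python
import numpy as np


def truncate_detections(num_dets, boxes, scores, labels):
    """Keep only the first `num_dets` entries of the padded outputs (batch 1)."""
    n = int(num_dets[0, 0])
    return boxes[0, :n], scores[0, :n], labels[0, :n]


# Dummy outputs with the documented shapes; real values come from the backend.
num_dets = np.array([[3]], dtype=np.int32)
boxes = np.zeros((1, 100, 4), dtype=np.float32)
scores = np.zeros((1, 100), dtype=np.float32)
labels = np.zeros((1, 100), dtype=np.int32)

kept_boxes, kept_scores, kept_labels = truncate_detections(
    num_dets, boxes, scores, labels)
print(kept_boxes.shape, kept_scores.shape, kept_labels.shape)
```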
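## 2. Export only the model Backbone + Neck
When you need to deploy on a platform that, unlike `TensorRT` or `ONNXRuntime`, does not support end-to-end deployment, consider passing `--model-only` and omitting `--backend`. This exports a model that contains only `Backbone` + `Neck`. Part of the model's outputs are shown below: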
<div align=center>
<img src="https://user-images.githubusercontent.com/92794867/232406169-40eee9fd-bc53-4fdc-bd37-d0e9033826f9.png"/>
</div>
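An `ONNX` model exported this way has the following advantages:
- Simple operators: it generally contains only `Conv`, activation functions, and other simple operators, so export almost never fails, which makes it friendlier to embedded deployment.
- Easier speed comparison between algorithms: because post-processing differs from one algorithm to another, comparing only the `Backbone` + `Neck` speed is fairer.
It also has the following disadvantages:
- The post-processing logic must be implemented separately, which means extra `decode` + `nms` operations.
- Compared with `TensorRT`, which can use multiple cores to run post-processing in parallel, a model exported with `--model-only` performs considerably worse.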
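To give a sense of what the extra `decode` step involves, here is a minimal sketch of YOLOv5-style decoding for a single grid cell, following the formulas used in the `numpy_coder.py` added by this commit. The function name `decode_yolov5_cell` and the sample values are illustrative assumptions, not part of the project.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def decode_yolov5_cell(raw_box, grid_x, grid_y, stride, anchor_wh):
    """Decode one raw (x, y, w, h) prediction into (x0, y0, w, h) pixels."""
    x, y, w, h = sigmoid(np.asarray(raw_box, dtype=np.float32))
    a_w, a_h = anchor_wh
    # Centre offset relative to the grid cell, scaled to input pixels.
    cx = (x * 2.0 - 0.5 + grid_x) * stride
    cy = (y * 2.0 - 0.5 + grid_y) * stride
    # Width and height are scaled relative to the anchor size.
    bw = (w * 2.0) ** 2 * a_w
    bh = (h * 2.0) ** 2 * a_h
    return np.array([cx - bw / 2, cy - bh / 2, bw, bh], dtype=np.float32)


# Raw logits of 0 give sigmoid 0.5: the box is centred in cell (3, 2)
# and has exactly the anchor size.
box = decode_yolov5_cell([0.0, 0.0, 0.0, 0.0], grid_x=3, grid_y=2,
                         stride=8, anchor_wh=(10, 13))
print(box)
```

The decoded boxes still have to be filtered by score and passed through `nms` before they can be used as detections.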
### Usage
```shell
python ./projects/easydeploy/tools/export.py \
configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
yolov5s.pth \
--work-dir work_dir \
--img-size 640 640 \
--batch 1 \
--device cpu \
--simplify \
--opset 11 \
--model-only
```
## Inference with an ONNX exported using `model-only`
The [model inference script](./projects/easydeploy/examples/main_onnxruntime.py) runs inference on the exported `ONNX` model. It requires the following basic dependencies:
[`onnxruntime`](https://github.com/microsoft/onnxruntime) and [`opencv-python`](https://github.com/opencv/opencv-python)
```shell
pip install onnxruntime
pip install opencv-python==4.7.0.72 # using the latest opencv is recommended
```
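### Arguments
- `img`: Path to the image to detect, or to a folder of images.
- `onnx`: The exported `model-only` ONNX model.
- `--type`: Model name. Currently supported: `yolov5`, `yolox`, `yolov6`, `ppyoloe`, `ppyoloep`, `yolov7`, `rtmdet`, `yolov8`
- `--img-size`: Input size used when the model was converted, e.g. `640 640`
- `--out-dir`: Path where the detection results are saved.
- `--show`: Whether to visualize the detection results.
- `--score-thr`: Confidence score threshold used in detection post-processing.
- `--iou-thr`: IOU threshold used in detection post-processing.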
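To make the roles of `--score-thr` and `--iou-thr` concrete, the following is a minimal sketch of greedy, single-class non-maximum suppression in NumPy. It is not the code the script runs (the script delegates to OpenCV's `cv2.dnn` NMS functions); the function name `simple_nms` and the sample boxes are hypothetical.

```python
import numpy as np


def simple_nms(boxes, scores, score_thr=0.3, iou_thr=0.7):
    """Greedy NMS over (x0, y0, x1, y1) boxes; returns the kept indices."""
    boxes = np.asarray(boxes, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    # Candidates above the score threshold, best first.
    order = [i for i in np.argsort(-scores) if scores[i] > score_thr]
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    keep = []
    while order:
        best = order.pop(0)
        keep.append(int(best))
        remaining = []
        for i in order:
            # Intersection rectangle of `best` and box `i`.
            iw = (min(boxes[best, 2], boxes[i, 2]) -
                  max(boxes[best, 0], boxes[i, 0]))
            ih = (min(boxes[best, 3], boxes[i, 3]) -
                  max(boxes[best, 1], boxes[i, 1]))
            inter = max(iw, 0.0) * max(ih, 0.0)
            iou = inter / (areas[best] + areas[i] - inter)
            # Boxes overlapping the kept box too much are suppressed.
            if iou <= iou_thr:
                remaining.append(i)
        order = remaining
    return keep


boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30], [0, 0, 5, 5]]
scores = [0.9, 0.8, 0.7, 0.1]
print(simple_nms(boxes, scores, score_thr=0.3, iou_thr=0.5))
```

In this example the second box overlaps the first with an IoU of about 0.68, so it is suppressed at `iou_thr=0.5` but would survive at the script's default of `0.7`; the last box is dropped by the score threshold alone.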
## Usage
```shell
cd ./projects/easydeploy/examples
python main_onnxruntime.py \
"image_path_to_detect" \
yolov5_s_model-only.onnx \
--type yolov5 \
--out-dir work_dir \
--img-size 640 640 \
--show \
--score-thr 0.3 \
--iou-thr 0.7
```
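*Note!!!*
When you use a model trained on a custom dataset, modify `CLASS_NAMES` and `CLASS_COLORS` in [`config.py`](./projects/easydeploy/examples/config.py). For `anchor`-based models such as `yolov5` or `yolov7`, also modify `YOLOv5_ANCHORS` and `YOLOv7_ANCHORS`.
[`numpy_coder.py`](./projects/easydeploy/examples/numpy_coder.py) is the `decoder` for all currently supported algorithms, implemented using only `numpy`. If you have high performance requirements, you can rewrite it in `c/c++` following that code.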
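As an illustration of the kind of logic `numpy_coder.py` contains, the sketch below shows the distribution-based box decoding used for `ppyoloe`- and `yolov8`-style heads: each of the four box sides is predicted as logits over `reg_max` bins, and the side distance is the softmax-weighted expectation of the bin indices. The function name `dfl_distances` and the sample logits are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)


def dfl_distances(box_logits, reg_max=16):
    """Turn (N, 4 * reg_max) logits into (N, 4) expected side distances."""
    bins = np.arange(reg_max, dtype=np.float32)
    probs = softmax(box_logits.reshape(-1, 4, reg_max), axis=-1)
    # Expectation over the bin indices for each of the four sides.
    return probs @ bins


# One proposal whose four distributions peak sharply at bins 2, 3, 4 and 5.
logits = np.full((1, 4, 16), -50.0, dtype=np.float32)
for side, peak in enumerate((2, 3, 4, 5)):
    logits[0, side, peak] = 50.0
distances = dfl_distances(logits.reshape(1, 64))
print(np.round(distances, 3))
```

The four distances are measured in grid cells from the cell centre; the decoder in this commit turns them into pixel coordinates with expressions of the form `(wIdx + 0.5 - x0) * stride`.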

View File

@ -0,0 +1,64 @@
from enum import Enum


class TASK_TYPE(Enum):
    DET = 'det'
    SEG = 'seg'
    POSE = 'pose'


class ModelType(Enum):
    YOLOV5 = 'yolov5'
    YOLOX = 'yolox'
    PPYOLOE = 'ppyoloe'
    PPYOLOEP = 'ppyoloep'
    YOLOV6 = 'yolov6'
    YOLOV7 = 'yolov7'
    RTMDET = 'rtmdet'
    YOLOV8 = 'yolov8'

CLASS_NAMES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
CLASS_COLORS = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
(106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
(0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
(175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
(0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
(110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
(255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
(0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
(78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
(134, 134, 103), (145, 148, 174), (255, 208, 186),
(197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
(151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
(166, 196, 102), (208, 195, 210), (255, 109, 65),
(0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
(227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
(163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
(183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
(166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
(65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
(196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
(246, 0, 122), (191, 162, 208)]
YOLOv5_ANCHORS = [[(10, 13), (16, 30), (33, 23)],
[(30, 61), (62, 45), (59, 119)],
[(116, 90), (156, 198), (373, 326)]]
YOLOv7_ANCHORS = [[(12, 16), (19, 36), (40, 28)],
[(36, 75), (76, 55), (72, 146)],
[(142, 110), (192, 243), (459, 401)]]

View File

@ -0,0 +1,36 @@
from typing import List, Tuple, Union

import cv2
from numpy import ndarray

MAJOR, MINOR = map(int, cv2.__version__.split('.')[:2])
assert MAJOR == 4


def non_max_suppression(boxes: Union[List[ndarray], Tuple[ndarray]],
                        scores: Union[List[float], Tuple[float]],
                        labels: Union[List[int], Tuple[int]],
                        conf_thres: float = 0.25,
                        iou_thres: float = 0.65) -> Tuple[List, List, List]:
    if MINOR >= 7:
        indices = cv2.dnn.NMSBoxesBatched(boxes, scores, labels, conf_thres,
                                          iou_thres)
    elif MINOR == 6:
        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
    else:
        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres,
                                   iou_thres).flatten()

    nmsd_boxes = []
    nmsd_scores = []
    nmsd_labels = []
    for idx in indices:
        box = boxes[idx]
        # x0y0wh -> x0y0x1y1
        box[2:] = box[:2] + box[2:]
        score = scores[idx]
        label = labels[idx]
        nmsd_boxes.append(box)
        nmsd_scores.append(score)
        nmsd_labels.append(label)
    return nmsd_boxes, nmsd_scores, nmsd_labels

View File

@ -0,0 +1,110 @@
import math
import sys
from argparse import ArgumentParser
from pathlib import Path
import cv2
import onnxruntime
from config import (CLASS_COLORS, CLASS_NAMES, ModelType, YOLOv5_ANCHORS,
YOLOv7_ANCHORS)
from cv2_nms import non_max_suppression
from numpy_coder import Decoder
from preprocess import Preprocess
from tqdm import tqdm
# Add __FILE__ to sys.path
sys.path.append(str(Path(__file__).resolve().parents[0]))
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')


def path_to_list(path: str):
    path = Path(path)
    if path.is_file() and path.suffix in IMG_EXTENSIONS:
        res_list = [str(path.absolute())]
    elif path.is_dir():
        res_list = [
            str(p.absolute()) for p in path.iterdir()
            if p.suffix in IMG_EXTENSIONS
        ]
    else:
        raise RuntimeError
    return res_list


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', help='Path to an image file or a directory of images.')
    parser.add_argument('onnx', type=str, help='Onnx file')
    # main() calls args.type.lower(), so the model type must be given.
    parser.add_argument(
        '--type', type=str, required=True, help='Model type')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument(
        '--out-dir', default='./output', type=str, help='Path to output file')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--iou-thr', type=float, default=0.7, help='Bbox iou threshold')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    out_dir = Path(args.out_dir)
    model_type = ModelType(args.type.lower())

    if not args.show:
        out_dir.mkdir(parents=True, exist_ok=True)

    files = path_to_list(args.img)
    session = onnxruntime.InferenceSession(
        args.onnx, providers=['CPUExecutionProvider'])
    preprocessor = Preprocess(model_type)
    decoder = Decoder(model_type, model_only=True)
    if model_type == ModelType.YOLOV5:
        anchors = YOLOv5_ANCHORS
    elif model_type == ModelType.YOLOV7:
        anchors = YOLOv7_ANCHORS
    else:
        anchors = None

    for file in tqdm(files):
        image = cv2.imread(file)
        image_h, image_w = image.shape[:2]
        img, (ratio_w, ratio_h) = preprocessor(image, args.img_size)
        features = session.run(None, {'images': img})
        decoder_outputs = decoder(
            features,
            args.score_thr,
            num_labels=len(CLASS_NAMES),
            anchors=anchors)
        nmsd_boxes, nmsd_scores, nmsd_labels = non_max_suppression(
            *decoder_outputs, args.score_thr, args.iou_thr)
        for box, score, label in zip(nmsd_boxes, nmsd_scores, nmsd_labels):
            x0, y0, x1, y1 = box
            x0 = math.floor(min(max(x0 / ratio_w, 1), image_w - 1))
            y0 = math.floor(min(max(y0 / ratio_h, 1), image_h - 1))
            x1 = math.ceil(min(max(x1 / ratio_w, 1), image_w - 1))
            y1 = math.ceil(min(max(y1 / ratio_h, 1), image_h - 1))
            cv2.rectangle(image, (x0, y0), (x1, y1), CLASS_COLORS[label], 2)
            cv2.putText(image, f'{CLASS_NAMES[label]}: {score:.2f}',
                        (x0, y0 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 255, 255), 2)
        if args.show:
            cv2.imshow('result', image)
            cv2.waitKey(0)
        else:
            cv2.imwrite(f'{out_dir / Path(file).name}', image)


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,310 @@
from typing import List, Tuple, Union

import numpy as np
from config import ModelType
from numpy import ndarray


def softmax(x: ndarray, axis: int = -1) -> ndarray:
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    y = e_x / e_x.sum(axis=axis, keepdims=True)
    return y


def sigmoid(x: ndarray) -> ndarray:
    return 1. / (1. + np.exp(-x))


class Decoder:

    def __init__(self, model_type: ModelType, model_only: bool = False):
        self.model_type = model_type
        self.model_only = model_only
        self.boxes_pro = []
        self.scores_pro = []
        self.labels_pro = []
        self.is_logging = False

    def __call__(self,
                 feats: Union[List, Tuple],
                 conf_thres: float,
                 num_labels: int = 80,
                 **kwargs) -> Tuple:
        if not self.is_logging:
            print('Only support decode in batch==1')
            self.is_logging = True
        self.boxes_pro.clear()
        self.scores_pro.clear()
        self.labels_pro.clear()

        if self.model_only:
            # transpose channel to last dim for easy decoding
            feats = [
                np.ascontiguousarray(feat[0].transpose(1, 2, 0))
                for feat in feats
            ]
        else:
            # ax620a horizonX3 transpose channel to last dim by default
            feats = [np.ascontiguousarray(feat) for feat in feats]
        if self.model_type == ModelType.YOLOV5:
            self.__yolov5_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOX:
            self.__yolox_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type in (ModelType.PPYOLOE, ModelType.PPYOLOEP):
            self.__ppyoloe_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV6:
            self.__yolov6_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV7:
            self.__yolov7_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.RTMDET:
            self.__rtmdet_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV8:
            self.__yolov8_decode(feats, conf_thres, num_labels, **kwargs)
        else:
            raise NotImplementedError
        return self.boxes_pro, self.scores_pro, self.labels_pro

    def __yolov5_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        anchors: Union[List, Tuple] = kwargs.get(
            'anchors',
            [[(10, 13), (16, 30),
              (33, 23)], [(30, 61), (62, 45),
                          (59, 119)], [(116, 90), (156, 198), (373, 326)]])
        for i, feat in enumerate(feats):
            stride = 8 << i
            feat_h, feat_w, _ = feat.shape
            anchor = anchors[i]
            feat = sigmoid(feat)
            feat = feat.reshape((feat_h, feat_w, len(anchor), -1))
            box_feat, conf_feat, score_feat = np.split(feat, [4, 5], -1)

            hIdx, wIdx, aIdx, _ = np.where(conf_feat > conf_thres)

            num_proposal = hIdx.size
            if not num_proposal:
                continue

            score_feat = score_feat[hIdx, wIdx, aIdx] * conf_feat[hIdx, wIdx,
                                                                  aIdx]
            boxes = box_feat[hIdx, wIdx, aIdx]
            labels = score_feat.argmax(-1)
            scores = score_feat.max(-1)

            indices = np.where(scores > conf_thres)[0]
            if len(indices) == 0:
                continue

            for idx in indices:
                a_w, a_h = anchor[aIdx[idx]]
                x, y, w, h = boxes[idx]
                x = (x * 2.0 - 0.5 + wIdx[idx]) * stride
                y = (y * 2.0 - 0.5 + hIdx[idx]) * stride
                w = (w * 2.0)**2 * a_w
                h = (h * 2.0)**2 * a_h

                x0 = x - w / 2
                y0 = y - h / 2

                self.scores_pro.append(float(scores[idx]))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(labels[idx]))

    def __yolox_decode(self,
                       feats: List[ndarray],
                       conf_thres: float,
                       num_labels: int = 80,
                       **kwargs):
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat, conf_feat = np.split(
                feat, [num_labels, num_labels + 4], -1)
            conf_feat = sigmoid(conf_feat)

            hIdx, wIdx, _ = np.where(conf_feat > conf_thres)

            num_proposal = hIdx.size
            if not num_proposal:
                continue

            score_feat = sigmoid(score_feat[hIdx, wIdx]) * conf_feat[hIdx,
                                                                     wIdx]
            boxes = box_feat[hIdx, wIdx]
            labels = score_feat.argmax(-1)
            scores = score_feat.max(-1)
            indices = np.where(scores > conf_thres)[0]

            if len(indices) == 0:
                continue

            for idx in indices:
                score = scores[idx]
                label = labels[idx]

                x, y, w, h = boxes[idx]

                x = (x + wIdx[idx]) * stride
                y = (y + hIdx[idx]) * stride
                w = np.exp(w) * stride
                h = np.exp(h) * stride

                x0 = x - w / 2
                y0 = y - h / 2

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __ppyoloe_decode(self,
                         feats: List[ndarray],
                         conf_thres: float,
                         num_labels: int = 80,
                         **kwargs):
        reg_max: int = kwargs.get('reg_max', 17)
        dfl = np.arange(0, reg_max, dtype=np.float32)
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat = np.split(feat, [
                num_labels,
            ], -1)
            score_feat = sigmoid(score_feat)
            _argmax = score_feat.argmax(-1)
            _max = score_feat.max(-1)
            indices = np.where(_max > conf_thres)
            hIdx, wIdx = indices
            num_proposal = hIdx.size
            if not num_proposal:
                continue

            scores = _max[hIdx, wIdx]
            boxes = box_feat[hIdx, wIdx].reshape(num_proposal, 4, reg_max)
            boxes = softmax(boxes, -1) @ dfl
            labels = _argmax[hIdx, wIdx]

            for k in range(num_proposal):
                score = scores[k]
                label = labels[k]

                x0, y0, x1, y1 = boxes[k]

                x0 = (wIdx[k] + 0.5 - x0) * stride
                y0 = (hIdx[k] + 0.5 - y0) * stride
                x1 = (wIdx[k] + 0.5 + x1) * stride
                y1 = (hIdx[k] + 0.5 + y1) * stride

                w = x1 - x0
                h = y1 - y0

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __yolov6_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat = np.split(feat, [
                num_labels,
            ], -1)
            score_feat = sigmoid(score_feat)
            _argmax = score_feat.argmax(-1)
            _max = score_feat.max(-1)
            indices = np.where(_max > conf_thres)
            hIdx, wIdx = indices
            num_proposal = hIdx.size
            if not num_proposal:
                continue

            scores = _max[hIdx, wIdx]
            boxes = box_feat[hIdx, wIdx]
            labels = _argmax[hIdx, wIdx]

            for k in range(num_proposal):
                score = scores[k]
                label = labels[k]

                x0, y0, x1, y1 = boxes[k]

                x0 = (wIdx[k] + 0.5 - x0) * stride
                y0 = (hIdx[k] + 0.5 - y0) * stride
                x1 = (wIdx[k] + 0.5 + x1) * stride
                y1 = (hIdx[k] + 0.5 + y1) * stride

                w = x1 - x0
                h = y1 - y0

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __yolov7_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        anchors: Union[List, Tuple] = kwargs.get(
            'anchors',
            [[(12, 16), (19, 36),
              (40, 28)], [(36, 75), (76, 55),
                          (72, 146)], [(142, 110), (192, 243), (459, 401)]])
        self.__yolov5_decode(feats, conf_thres, num_labels, anchors=anchors)

    def __rtmdet_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat = np.split(feat, [
                num_labels,
            ], -1)
            score_feat = sigmoid(score_feat)
            _argmax = score_feat.argmax(-1)
            _max = score_feat.max(-1)
            indices = np.where(_max > conf_thres)
            hIdx, wIdx = indices
            num_proposal = hIdx.size
            if not num_proposal:
                continue

            scores = _max[hIdx, wIdx]
            boxes = box_feat[hIdx, wIdx]
            labels = _argmax[hIdx, wIdx]

            for k in range(num_proposal):
                score = scores[k]
                label = labels[k]

                x0, y0, x1, y1 = boxes[k]

                x0 = (wIdx[k] - x0) * stride
                y0 = (hIdx[k] - y0) * stride
                x1 = (wIdx[k] + x1) * stride
                y1 = (hIdx[k] + y1) * stride

                w = x1 - x0
                h = y1 - y0

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __yolov8_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        reg_max: int = kwargs.get('reg_max', 16)
        self.__ppyoloe_decode(feats, conf_thres, num_labels, reg_max=reg_max)

View File

@ -0,0 +1,57 @@
from typing import List, Tuple, Union

import cv2
import numpy as np
from config import ModelType
from numpy import ndarray


class Preprocess:

    def __init__(self, model_type: ModelType):
        if model_type in (ModelType.YOLOV5, ModelType.YOLOV6, ModelType.YOLOV7,
                          ModelType.YOLOV8):
            mean = np.array([0, 0, 0], dtype=np.float32)
            std = np.array([255, 255, 255], dtype=np.float32)
            is_rgb = True
        elif model_type == ModelType.YOLOX:
            mean = np.array([0, 0, 0], dtype=np.float32)
            std = np.array([1, 1, 1], dtype=np.float32)
            is_rgb = False
        elif model_type == ModelType.PPYOLOE:
            mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
            std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
            is_rgb = True
        elif model_type == ModelType.PPYOLOEP:
            mean = np.array([0, 0, 0], dtype=np.float32)
            std = np.array([255, 255, 255], dtype=np.float32)
            is_rgb = True
        elif model_type == ModelType.RTMDET:
            # BGR order; 58.395 matches the value used for PPYOLOE above.
            mean = np.array([103.53, 116.28, 123.675], dtype=np.float32)
            std = np.array([57.375, 57.12, 58.395], dtype=np.float32)
            is_rgb = False
        else:
            raise NotImplementedError

        self.mean = mean.reshape((3, 1, 1))
        self.std = std.reshape((3, 1, 1))
        self.is_rgb = is_rgb

    def __call__(self,
                 image: ndarray,
                 new_size: Union[List[int], Tuple[int]] = (640, 640),
                 **kwargs) -> Tuple[ndarray, Tuple[float, float]]:
        # new_size: (height, width)
        height, width = image.shape[:2]
        ratio_h, ratio_w = new_size[0] / height, new_size[1] / width
        image = cv2.resize(
            image, (0, 0),
            fx=ratio_w,
            fy=ratio_h,
            interpolation=cv2.INTER_LINEAR)
        if self.is_rgb:
            # Assumes a BGR input, as returned by cv2.imread.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.ascontiguousarray(image.transpose(2, 0, 1))
        image = image.astype(np.float32)
        image -= self.mean
        image /= self.std
        return image[np.newaxis], (ratio_w, ratio_h)

View File

@ -0,0 +1,2 @@
onnxruntime
opencv-python==4.7.0.72