mirror of https://github.com/open-mmlab/mmyolo.git

Support model-only inference (#733)

* Support model-only inference
* Fix ppyoloe std
* Add doc
* Fix typo
* Rename

parent 927e0a46af
commit 1aa1ecd27b
@@ -1,5 +1,7 @@

# MMYOLO Model ONNX Conversion

## 1. Export an ONNX File Supported by the Backend

## Environment Dependencies

- [onnx](https://github.com/onnx/onnx)
@@ -14,6 +16,8 @@
pip install onnx-simplifier
```

***Please make sure you run the relevant scripts from the `MMYOLO` root directory, otherwise the required dependency packages may not be found.***

## Usage

The [model export script](./projects/easydeploy/tools/export_onnx.py) converts an `MMYOLO` model to `onnx`.
@@ -28,7 +32,7 @@
- `--device`: the device used for conversion, defaults to `cuda:0`.
- `--simplify`: whether to simplify the exported `onnx` model; requires [onnx-simplifier](https://github.com/daquexian/onnx-simplifier), off by default.
- `--opset`: the `onnx` `opset` version to export with, defaults to `11`.
- `--backend`: the name of the backend the exported `onnx` targets, `ONNXRuntime`: `onnxruntime`, `TensorRT8`: `tensorrt8`, `TensorRT7`: `tensorrt7`; defaults to `onnxruntime`, i.e. `ONNXRuntime`.
- `--pre-topk`: the number of candidate boxes kept by post-processing filtering in the exported `onnx`, defaults to `1000`.
- `--keep-topk`: the number of candidate boxes output by non-maximum suppression in the exported `onnx`, defaults to `100`.
- `--iou-threshold`: the `iou` threshold used by non-maximum suppression to filter duplicate candidate boxes, defaults to `0.65`.
@@ -54,4 +58,99 @@ python ./projects/easydeploy/tools/export.py \
    --score-threshold 0.25
```

Then use a tool supported by the backend, such as `TensorRT`, to read the `onnx` file and convert it once more into a model format the backend supports, such as `.engine/.plan`.

`MMYOLO` currently supports end-to-end model conversion for the `TensorRT8`, `TensorRT7` and `ONNXRuntime` backends. Only static-shape models can be exported and converted at present; end-to-end conversion of models with dynamic batch size or dynamic height/width will be supported in the future.
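As a reference, here is a minimal sketch of that second conversion step using the `TensorRT` Python API. This is only an illustration under assumptions, not this project's deploy tooling: it assumes a `TensorRT8` environment, and `yolov5s_e2e.onnx` / `yolov5s_e2e.engine` are placeholder file names.

```python
# Minimal sketch: build a TensorRT engine from the exported end-to-end ONNX.
# Assumes TensorRT 8.x; file names below are placeholders.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
# register built-in plugins, since the end-to-end graph may contain
# TensorRT NMS plugin ops
trt.init_libnvinfer_plugins(logger, '')

builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open('yolov5s_e2e.onnx', 'rb') as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # 1 GiB builder workspace
engine = builder.build_serialized_network(network, config)

with open('yolov5s_e2e.engine', 'wb') as f:
    f.write(engine)
```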
The inputs and outputs of the end-to-end converted `onnx` model are shown in the figure:

<div align=center>
<img src="https://user-images.githubusercontent.com/92794867/232403745-101ca999-2003-46fa-bc5b-6b0eb2b2d41b.png"/>
</div>

Input name: `images`, size 640x640.

Output name: `num_dets`, size 1x1, the number of detected objects.

Output name: `boxes`, size 1x100x4, the detection box coordinates in `x1y1x2y2` format.

Output name: `scores`, size 1x100, the scores of the detection boxes.

Output name: `labels`, size 1x100, the class ids of the detection boxes.

You can use the count in `num_dets` to truncate `boxes`, `scores` and `labels`, taking the first `num_dets` of the 100 detection results as the final output.
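For example, a minimal `ONNXRuntime` sketch of this truncation (`yolov5s_e2e.onnx` is a placeholder file name, and the random input stands in for a properly preprocessed image):

```python
import numpy as np
import onnxruntime

# 'yolov5s_e2e.onnx' is a placeholder name for the end-to-end exported model
session = onnxruntime.InferenceSession(
    'yolov5s_e2e.onnx', providers=['CPUExecutionProvider'])

# dummy input; real usage feeds a normalized 1x3x640x640 image tensor
img = np.random.rand(1, 3, 640, 640).astype(np.float32)
num_dets, boxes, scores, labels = session.run(None, {'images': img})

# truncate the 100 padded detections down to the first num_dets results
n = int(num_dets[0, 0])
final_boxes = boxes[0, :n]
final_scores = scores[0, :n]
final_labels = labels[0, :n]
```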
## 2. Export Only the Model Backbone + Neck

When you need to deploy on platforms that, unlike `TensorRT` and `ONNXRuntime`, do not support end-to-end deployment, you can pass the `--model-only` flag and omit the `--backend` flag. This exports a model containing only the `Backbone` + `Neck`; part of its output is shown in the figure:

<div align=center>
<img src="https://user-images.githubusercontent.com/92794867/232406169-40eee9fd-bc53-4fdc-bd37-d0e9033826f9.png"/>
</div>

An `ONNX` model exported this way has the following advantages:

- Simple operators: generally only `Conv`, activation functions and other simple operators are involved, so exports almost never fail, which makes it more friendly for embedded deployment.
- Fairer speed comparison between algorithms: since different algorithms use different post-processing, comparing only the `Backbone` + `Neck` speed is more equitable.

It also has the following disadvantages:

- The post-processing logic has to be implemented separately; an extra `decode` + `nms` step is required.
- Compared with `TensorRT`, which can run post-processing in parallel on many cores, a model exported with `--model-only` performs considerably worse.
### Usage

```shell
python ./projects/easydeploy/tools/export.py \
    configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
    yolov5s.pth \
    --work-dir work_dir \
    --img-size 640 640 \
    --batch 1 \
    --device cpu \
    --simplify \
    --opset 11 \
    --model-only
```
## Inference with a `model-only` Exported ONNX Model

The [model inference script](./projects/easydeploy/examples/main_onnxruntime.py) runs inference on the exported `ONNX` model and requires the following basic dependencies:

[`onnxruntime`](https://github.com/microsoft/onnxruntime) and [`opencv-python`](https://github.com/opencv/opencv-python)

```shell
pip install onnxruntime
pip install opencv-python==4.7.0.72  # the latest opencv is recommended
```
### Parameters

- `img`: the path of the image, or of the image folder, to run detection on.
- `onnx`: the `model-only` exported ONNX model.
- `--type`: the model name; currently supports `yolov5`, `yolox`, `yolov6`, `ppyoloe`, `ppyoloep`, `yolov7`, `rtmdet`, `yolov8`.
- `--img-size`: the input size used when converting the model, e.g. `640 640`.
- `--out-dir`: the path where detection results are saved.
- `--show`: whether to visualize the detection results.
- `--score-thr`: the confidence score threshold for post-processing.
- `--iou-thr`: the IoU threshold for post-processing.
## Usage

```shell
cd ./projects/easydeploy/examples
python main_onnxruntime.py \
    "image_path_to_detect" \
    yolov5_s_model-only.onnx \
    --out-dir work_dir \
    --img-size 640 640 \
    --show \
    --score-thr 0.3 \
    --iou-thr 0.7
```
*Note!!!*

When using a model trained on a custom dataset, please modify `CLASS_NAMES` and `CLASS_COLORS` in [`config.py`](./projects/easydeploy/examples/config.py); for `anchor`-based models such as `yolov5` or `yolov7`, please also modify `YOLOv5_ANCHORS` and `YOLOv7_ANCHORS`.
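For example, a hypothetical two-class custom dataset would override the entries in `config.py` roughly as follows (the names and colors here are made up):

```python
# Hypothetical config.py override for a custom 2-class dataset:
# one name and one color per class, indexed by class id.
CLASS_NAMES = ('cat', 'dog')
CLASS_COLORS = [(220, 20, 60), (119, 11, 32)]
```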
[`numpy_coder.py`](./projects/easydeploy/examples/numpy_coder.py) contains the `decoder` for all of the algorithms, implemented purely in `numpy`. If you have high performance requirements, you can use the relevant code as a reference and rewrite it in `c/c++`.
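Its usage mirrors `main_onnxruntime.py`; condensed to the essential calls, a `yolov5` run looks roughly like this (`yolov5_s_model-only.onnx` and `demo.jpg` are placeholder file names):

```python
import cv2
import onnxruntime
from config import CLASS_NAMES, ModelType, YOLOv5_ANCHORS
from cv2_nms import non_max_suppression
from numpy_coder import Decoder
from preprocess import Preprocess

session = onnxruntime.InferenceSession(
    'yolov5_s_model-only.onnx', providers=['CPUExecutionProvider'])
img, (ratio_w, ratio_h) = Preprocess(ModelType.YOLOV5)(
    cv2.imread('demo.jpg'), (640, 640))

# raw Backbone + Neck feature maps -> decoded xywh boxes -> host-side NMS
features = session.run(None, {'images': img})
boxes, scores, labels = Decoder(ModelType.YOLOV5, model_only=True)(
    features, 0.3, num_labels=len(CLASS_NAMES), anchors=YOLOv5_ANCHORS)
nmsd_boxes, nmsd_scores, nmsd_labels = non_max_suppression(
    boxes, scores, labels, conf_thres=0.3, iou_thres=0.7)
```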

@@ -0,0 +1,64 @@ projects/easydeploy/examples/config.py

from enum import Enum


class TASK_TYPE(Enum):
    DET = 'det'
    SEG = 'seg'
    POSE = 'pose'


class ModelType(Enum):
    YOLOV5 = 'yolov5'
    YOLOX = 'yolox'
    PPYOLOE = 'ppyoloe'
    PPYOLOEP = 'ppyoloep'
    YOLOV6 = 'yolov6'
    YOLOV7 = 'yolov7'
    RTMDET = 'rtmdet'
    YOLOV8 = 'yolov8'


CLASS_NAMES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
               'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

CLASS_COLORS = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
                (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
                (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
                (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
                (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
                (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
                (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
                (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
                (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
                (134, 134, 103), (145, 148, 174), (255, 208, 186),
                (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
                (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
                (166, 196, 102), (208, 195, 210), (255, 109, 65),
                (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0),
                (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161),
                (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120),
                (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133),
                (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62),
                (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45),
                (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1),
                (246, 0, 122), (191, 162, 208)]

YOLOv5_ANCHORS = [[(10, 13), (16, 30), (33, 23)],
                  [(30, 61), (62, 45), (59, 119)],
                  [(116, 90), (156, 198), (373, 326)]]

YOLOv7_ANCHORS = [[(12, 16), (19, 36), (40, 28)],
                  [(36, 75), (76, 55), (72, 146)],
                  [(142, 110), (192, 243), (459, 401)]]

@@ -0,0 +1,36 @@ projects/easydeploy/examples/cv2_nms.py

from typing import List, Tuple, Union

import cv2
from numpy import ndarray

MAJOR, MINOR = map(int, cv2.__version__.split('.')[:2])
assert MAJOR == 4


def non_max_suppression(boxes: Union[List[ndarray], Tuple[ndarray]],
                        scores: Union[List[float], Tuple[float]],
                        labels: Union[List[int], Tuple[int]],
                        conf_thres: float = 0.25,
                        iou_thres: float = 0.65) -> Tuple[List, List, List]:
    if MINOR >= 7:
        indices = cv2.dnn.NMSBoxesBatched(boxes, scores, labels, conf_thres,
                                          iou_thres)
    elif MINOR == 6:
        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
    else:
        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres,
                                   iou_thres).flatten()

    nmsd_boxes = []
    nmsd_scores = []
    nmsd_labels = []
    for idx in indices:
        box = boxes[idx]
        # x0y0wh -> x0y0x1y1
        box[2:] = box[:2] + box[2:]
        score = scores[idx]
        label = labels[idx]
        nmsd_boxes.append(box)
        nmsd_scores.append(score)
        nmsd_labels.append(label)
    return nmsd_boxes, nmsd_scores, nmsd_labels

@@ -0,0 +1,110 @@ projects/easydeploy/examples/main_onnxruntime.py

import math
import sys
from argparse import ArgumentParser
from pathlib import Path

import cv2
import onnxruntime
from config import (CLASS_COLORS, CLASS_NAMES, ModelType, YOLOv5_ANCHORS,
                    YOLOv7_ANCHORS)
from cv2_nms import non_max_suppression
from numpy_coder import Decoder
from preprocess import Preprocess
from tqdm import tqdm

# Add __FILE__ to sys.path
sys.path.append(str(Path(__file__).resolve().parents[0]))

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def path_to_list(path: str):
    path = Path(path)
    if path.is_file() and path.suffix in IMG_EXTENSIONS:
        res_list = [str(path.absolute())]
    elif path.is_dir():
        res_list = [
            str(p.absolute()) for p in path.iterdir()
            if p.suffix in IMG_EXTENSIONS
        ]
    else:
        raise RuntimeError(f'{path} is neither an image file nor a directory')
    return res_list


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('onnx', type=str, help='Onnx file')
    parser.add_argument('--type', type=str, help='Model type')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument(
        '--out-dir', default='./output', type=str, help='Path to output file')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--iou-thr', type=float, default=0.7, help='Bbox iou threshold')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    out_dir = Path(args.out_dir)
    model_type = ModelType(args.type.lower())

    if not args.show:
        out_dir.mkdir(parents=True, exist_ok=True)

    files = path_to_list(args.img)
    session = onnxruntime.InferenceSession(
        args.onnx, providers=['CPUExecutionProvider'])
    preprocessor = Preprocess(model_type)
    decoder = Decoder(model_type, model_only=True)
    if model_type == ModelType.YOLOV5:
        anchors = YOLOv5_ANCHORS
    elif model_type == ModelType.YOLOV7:
        anchors = YOLOv7_ANCHORS
    else:
        anchors = None

    for file in tqdm(files):
        image = cv2.imread(file)
        image_h, image_w = image.shape[:2]
        img, (ratio_w, ratio_h) = preprocessor(image, args.img_size)
        features = session.run(None, {'images': img})
        decoder_outputs = decoder(
            features,
            args.score_thr,
            num_labels=len(CLASS_NAMES),
            anchors=anchors)
        nmsd_boxes, nmsd_scores, nmsd_labels = non_max_suppression(
            *decoder_outputs, args.score_thr, args.iou_thr)
        for box, score, label in zip(nmsd_boxes, nmsd_scores, nmsd_labels):
            x0, y0, x1, y1 = box
            # rescale boxes from the network input size back to the
            # original image and clip them to its bounds
            x0 = math.floor(min(max(x0 / ratio_w, 1), image_w - 1))
            y0 = math.floor(min(max(y0 / ratio_h, 1), image_h - 1))
            x1 = math.ceil(min(max(x1 / ratio_w, 1), image_w - 1))
            y1 = math.ceil(min(max(y1 / ratio_h, 1), image_h - 1))
            cv2.rectangle(image, (x0, y0), (x1, y1), CLASS_COLORS[label], 2)
            cv2.putText(image, f'{CLASS_NAMES[label]}: {score:.2f}',
                        (x0, y0 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 255, 255), 2)
        if args.show:
            cv2.imshow('result', image)
            cv2.waitKey(0)
        else:
            cv2.imwrite(f'{out_dir / Path(file).name}', image)


if __name__ == '__main__':
    main()

@@ -0,0 +1,310 @@ projects/easydeploy/examples/numpy_coder.py

from typing import List, Tuple, Union

import numpy as np
from config import ModelType
from numpy import ndarray


def softmax(x: ndarray, axis: int = -1) -> ndarray:
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    y = e_x / e_x.sum(axis=axis, keepdims=True)
    return y


def sigmoid(x: ndarray) -> ndarray:
    return 1. / (1. + np.exp(-x))


class Decoder:

    def __init__(self, model_type: ModelType, model_only: bool = False):
        self.model_type = model_type
        self.model_only = model_only
        self.boxes_pro = []
        self.scores_pro = []
        self.labels_pro = []
        self.is_logging = False

    def __call__(self,
                 feats: Union[List, Tuple],
                 conf_thres: float,
                 num_labels: int = 80,
                 **kwargs) -> Tuple:
        if not self.is_logging:
            print('Only support decode in batch==1')
            self.is_logging = True
        self.boxes_pro.clear()
        self.scores_pro.clear()
        self.labels_pro.clear()

        if self.model_only:
            # transpose channel to last dim for easy decoding
            feats = [
                np.ascontiguousarray(feat[0].transpose(1, 2, 0))
                for feat in feats
            ]
        else:
            # ax620a horizonX3 transpose channel to last dim by default
            feats = [np.ascontiguousarray(feat) for feat in feats]
        if self.model_type == ModelType.YOLOV5:
            self.__yolov5_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOX:
            self.__yolox_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type in (ModelType.PPYOLOE, ModelType.PPYOLOEP):
            self.__ppyoloe_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV6:
            self.__yolov6_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV7:
            self.__yolov7_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.RTMDET:
            self.__rtmdet_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV8:
            self.__yolov8_decode(feats, conf_thres, num_labels, **kwargs)
        else:
            raise NotImplementedError
        return self.boxes_pro, self.scores_pro, self.labels_pro

    def __yolov5_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        anchors: Union[List, Tuple] = kwargs.get(
            'anchors',
            [[(10, 13), (16, 30),
              (33, 23)], [(30, 61), (62, 45),
                          (59, 119)], [(116, 90), (156, 198), (373, 326)]])
        for i, feat in enumerate(feats):
            # strides are 8, 16 and 32 for the three feature map levels
            stride = 8 << i
            feat_h, feat_w, _ = feat.shape
            anchor = anchors[i]
            feat = sigmoid(feat)
            feat = feat.reshape((feat_h, feat_w, len(anchor), -1))
            box_feat, conf_feat, score_feat = np.split(feat, [4, 5], -1)

            hIdx, wIdx, aIdx, _ = np.where(conf_feat > conf_thres)

            num_proposal = hIdx.size
            if not num_proposal:
                continue

            score_feat = score_feat[hIdx, wIdx, aIdx] * conf_feat[hIdx, wIdx,
                                                                  aIdx]
            boxes = box_feat[hIdx, wIdx, aIdx]
            labels = score_feat.argmax(-1)
            scores = score_feat.max(-1)

            indices = np.where(scores > conf_thres)[0]
            if len(indices) == 0:
                continue

            for idx in indices:
                a_w, a_h = anchor[aIdx[idx]]
                x, y, w, h = boxes[idx]
                x = (x * 2.0 - 0.5 + wIdx[idx]) * stride
                y = (y * 2.0 - 0.5 + hIdx[idx]) * stride
                w = (w * 2.0)**2 * a_w
                h = (h * 2.0)**2 * a_h

                x0 = x - w / 2
                y0 = y - h / 2

                self.scores_pro.append(float(scores[idx]))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(labels[idx]))

    def __yolox_decode(self,
                       feats: List[ndarray],
                       conf_thres: float,
                       num_labels: int = 80,
                       **kwargs):
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat, conf_feat = np.split(
                feat, [num_labels, num_labels + 4], -1)
            conf_feat = sigmoid(conf_feat)

            hIdx, wIdx, _ = np.where(conf_feat > conf_thres)

            num_proposal = hIdx.size
            if not num_proposal:
                continue

            score_feat = sigmoid(score_feat[hIdx, wIdx]) * conf_feat[hIdx,
                                                                     wIdx]
            boxes = box_feat[hIdx, wIdx]
            labels = score_feat.argmax(-1)
            scores = score_feat.max(-1)
            indices = np.where(scores > conf_thres)[0]

            if len(indices) == 0:
                continue

            for idx in indices:
                score = scores[idx]
                label = labels[idx]

                x, y, w, h = boxes[idx]

                x = (x + wIdx[idx]) * stride
                y = (y + hIdx[idx]) * stride
                w = np.exp(w) * stride
                h = np.exp(h) * stride

                x0 = x - w / 2
                y0 = y - h / 2

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __ppyoloe_decode(self,
                         feats: List[ndarray],
                         conf_thres: float,
                         num_labels: int = 80,
                         **kwargs):
        reg_max: int = kwargs.get('reg_max', 17)
        # integral bins for DFL-style box regression
        dfl = np.arange(0, reg_max, dtype=np.float32)
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat = np.split(feat, [num_labels], -1)
            score_feat = sigmoid(score_feat)
            _argmax = score_feat.argmax(-1)
            _max = score_feat.max(-1)
            indices = np.where(_max > conf_thres)
            hIdx, wIdx = indices
            num_proposal = hIdx.size
            if not num_proposal:
                continue

            scores = _max[hIdx, wIdx]
            boxes = box_feat[hIdx, wIdx].reshape(num_proposal, 4, reg_max)
            # expected value over the distribution bins
            boxes = softmax(boxes, -1) @ dfl
            labels = _argmax[hIdx, wIdx]

            for k in range(num_proposal):
                score = scores[k]
                label = labels[k]

                x0, y0, x1, y1 = boxes[k]

                x0 = (wIdx[k] + 0.5 - x0) * stride
                y0 = (hIdx[k] + 0.5 - y0) * stride
                x1 = (wIdx[k] + 0.5 + x1) * stride
                y1 = (hIdx[k] + 0.5 + y1) * stride

                w = x1 - x0
                h = y1 - y0

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __yolov6_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat = np.split(feat, [num_labels], -1)
            score_feat = sigmoid(score_feat)
            _argmax = score_feat.argmax(-1)
            _max = score_feat.max(-1)
            indices = np.where(_max > conf_thres)
            hIdx, wIdx = indices
            num_proposal = hIdx.size
            if not num_proposal:
                continue

            scores = _max[hIdx, wIdx]
            boxes = box_feat[hIdx, wIdx]
            labels = _argmax[hIdx, wIdx]

            for k in range(num_proposal):
                score = scores[k]
                label = labels[k]

                x0, y0, x1, y1 = boxes[k]

                x0 = (wIdx[k] + 0.5 - x0) * stride
                y0 = (hIdx[k] + 0.5 - y0) * stride
                x1 = (wIdx[k] + 0.5 + x1) * stride
                y1 = (hIdx[k] + 0.5 + y1) * stride

                w = x1 - x0
                h = y1 - y0

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __yolov7_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        anchors: Union[List, Tuple] = kwargs.get(
            'anchors',
            [[(12, 16), (19, 36),
              (40, 28)], [(36, 75), (76, 55),
                          (72, 146)], [(142, 110), (192, 243), (459, 401)]])
        self.__yolov5_decode(feats, conf_thres, num_labels, anchors=anchors)

    def __rtmdet_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        for i, feat in enumerate(feats):
            stride = 8 << i
            score_feat, box_feat = np.split(feat, [num_labels], -1)
            score_feat = sigmoid(score_feat)
            _argmax = score_feat.argmax(-1)
            _max = score_feat.max(-1)
            indices = np.where(_max > conf_thres)
            hIdx, wIdx = indices
            num_proposal = hIdx.size
            if not num_proposal:
                continue

            scores = _max[hIdx, wIdx]
            boxes = box_feat[hIdx, wIdx]
            labels = _argmax[hIdx, wIdx]

            for k in range(num_proposal):
                score = scores[k]
                label = labels[k]

                x0, y0, x1, y1 = boxes[k]

                x0 = (wIdx[k] - x0) * stride
                y0 = (hIdx[k] - y0) * stride
                x1 = (wIdx[k] + x1) * stride
                y1 = (hIdx[k] + y1) * stride

                w = x1 - x0
                h = y1 - y0

                self.scores_pro.append(float(score))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(label))

    def __yolov8_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        reg_max: int = kwargs.get('reg_max', 16)
        self.__ppyoloe_decode(feats, conf_thres, num_labels, reg_max=reg_max)

@@ -0,0 +1,57 @@ projects/easydeploy/examples/preprocess.py

from typing import List, Tuple, Union

import cv2
import numpy as np
from config import ModelType
from numpy import ndarray


class Preprocess:

    def __init__(self, model_type: ModelType):
        if model_type in (ModelType.YOLOV5, ModelType.YOLOV6, ModelType.YOLOV7,
                          ModelType.YOLOV8):
            mean = np.array([0, 0, 0], dtype=np.float32)
            std = np.array([255, 255, 255], dtype=np.float32)
            is_rgb = True
        elif model_type == ModelType.YOLOX:
            mean = np.array([0, 0, 0], dtype=np.float32)
            std = np.array([1, 1, 1], dtype=np.float32)
            is_rgb = False
        elif model_type == ModelType.PPYOLOE:
            mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
            std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
            is_rgb = True
        elif model_type == ModelType.PPYOLOEP:
            mean = np.array([0, 0, 0], dtype=np.float32)
            std = np.array([255, 255, 255], dtype=np.float32)
            is_rgb = True
        elif model_type == ModelType.RTMDET:
            mean = np.array([103.53, 116.28, 123.675], dtype=np.float32)
            std = np.array([57.375, 57.12, 58.395], dtype=np.float32)
            is_rgb = False
        else:
            raise NotImplementedError

        self.mean = mean.reshape((3, 1, 1))
        self.std = std.reshape((3, 1, 1))
        self.is_rgb = is_rgb

    def __call__(self,
                 image: ndarray,
                 new_size: Union[List[int], Tuple[int]] = (640, 640),
                 **kwargs) -> Tuple[ndarray, Tuple[float, float]]:
        # new_size: (height, width)
        height, width = image.shape[:2]
        ratio_h, ratio_w = new_size[0] / height, new_size[1] / width
        image = cv2.resize(
            image, (0, 0),
            fx=ratio_w,
            fy=ratio_h,
            interpolation=cv2.INTER_LINEAR)
        if self.is_rgb:
            # the network expects RGB input, while cv2.imread returns BGR
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.ascontiguousarray(image.transpose(2, 0, 1))
        image = image.astype(np.float32)
        image -= self.mean
        image /= self.std
        return image[np.newaxis], (ratio_w, ratio_h)

@@ -0,0 +1,2 @@ projects/easydeploy/examples/requirements.txt

onnxruntime
opencv-python==4.7.0.72