# Mirror of https://github.com/open-mmlab/mmyolo.git
# 311 lines, 11 KiB, Python
from typing import List, Tuple, Union
|
|
|
|
import numpy as np
|
|
from config import ModelType
|
|
from numpy import ndarray
|
|
|
|
|
|
def softmax(x: ndarray, axis: int = -1) -> ndarray:
    """Return the softmax of ``x`` taken along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating; this
    leaves the result mathematically unchanged while keeping the argument
    of ``np.exp`` non-positive.

    Args:
        x: Input scores.
        axis: Axis along which the outputs are normalised. Defaults to the
            last axis.

    Returns:
        Array with the same shape as ``x`` whose entries along ``axis``
        sum to one.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)
|
|
|
|
|
|
def sigmoid(x: ndarray) -> ndarray:
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))``.

    The naive formula evaluates ``exp(-x)``, which overflows for large
    negative inputs. Here only ``exp(-|x|)`` is evaluated, which always
    lies in ``[0, 1]``, and the algebraically equivalent form is chosen
    according to the sign of ``x``.

    Args:
        x: Input logits.

    Returns:
        Array of sigmoid values with the same shape as ``x``.
    """
    x = np.asarray(x)
    # z = exp(-|x|) never overflows; for x >= 0 it equals exp(-x),
    # for x < 0 it equals exp(x).
    z = np.exp(-np.abs(x))
    return np.where(x >= 0, 1. / (1. + z), z / (1. + z))
|
|
|
|
|
|
class Decoder:
    """Decode raw detection-head feature maps into box proposals.

    An instance is configured with a ``ModelType`` and, when called with
    the per-level feature maps, dispatches to the matching decoder.
    Proposals are collected into three parallel lists: boxes (float32
    arrays ``[x0, y0, w, h]``), scores (``float``) and labels (``int``).
    Feature level ``i`` is decoded with stride ``8 << i``.

    Only a batch size of one is handled: when ``model_only`` is true,
    element 0 of each feature map is taken and its axes are moved from
    channel-first to channel-last.

    Args:
        model_type: Which detector family produced the feature maps.
        model_only: If True, each feature map is indexed with ``[0]`` and
            transposed with ``(1, 2, 0)`` before decoding. If False, the
            feature maps are used as given (expected channel-last).
    """

    def __init__(self, model_type: ModelType, model_only: bool = False):
        self.model_type = model_type
        self.model_only = model_only
        # Results of the most recent call.
        self.boxes_pro = []
        self.scores_pro = []
        self.labels_pro = []
        # Set once the batch-size notice has been printed.
        self.is_logging = False

    def __call__(self,
                 feats: Union[List, Tuple],
                 conf_thres: float,
                 num_labels: int = 80,
                 **kwargs) -> Tuple:
        """Decode ``feats`` and return ``(boxes, scores, labels)``.

        Args:
            feats: One feature map per detection level, ordered from the
                smallest stride upward.
            conf_thres: Proposals whose score is not strictly greater than
                this value are dropped.
            num_labels: Number of class channels in each feature map.
            **kwargs: Decoder-specific options forwarded to the selected
                decoder (``anchors`` for YOLOv5/YOLOv7, ``reg_max`` for
                PPYOLOE/YOLOv8).

        Returns:
            Tuple of three lists of equal length: boxes as float32 arrays
            ``[x0, y0, w, h]``, scores as ``float`` and labels as ``int``.

        Raises:
            NotImplementedError: If ``model_type`` has no decoder.
        """
        if not self.is_logging:
            print('Only support decode in batch==1')
            self.is_logging = True

        # Start from fresh lists instead of clearing the old ones in place,
        # so results handed out by an earlier call are not emptied behind
        # the caller's back.
        self.boxes_pro = []
        self.scores_pro = []
        self.labels_pro = []

        if self.model_only:
            # transpose channel to last dim for easy decoding
            feats = [
                np.ascontiguousarray(feat[0].transpose(1, 2, 0))
                for feat in feats
            ]
        else:
            # ax620a horizonX3 transpose channel to last dim by default
            feats = [np.ascontiguousarray(feat) for feat in feats]

        if self.model_type == ModelType.YOLOV5:
            self.__yolov5_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOX:
            self.__yolox_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type in (ModelType.PPYOLOE, ModelType.PPYOLOEP):
            self.__ppyoloe_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV6:
            self.__yolov6_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV7:
            self.__yolov7_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.RTMDET:
            self.__rtmdet_decode(feats, conf_thres, num_labels, **kwargs)
        elif self.model_type == ModelType.YOLOV8:
            self.__yolov8_decode(feats, conf_thres, num_labels, **kwargs)
        else:
            raise NotImplementedError(
                f'Unsupported model type: {self.model_type!r}')
        return self.boxes_pro, self.scores_pro, self.labels_pro

    def _decode_distance_head(self,
                              feats: List[ndarray],
                              conf_thres: float,
                              num_labels: int,
                              cell_offset: float,
                              reg_max: Union[int, None] = None):
        """Decode heads that predict per-cell distances to the box sides.

        Shared by the PPYOLOE/YOLOv8, YOLOv6 and RTMDet decoders. Each
        feature map is split into ``num_labels`` class channels followed
        by box channels. The best class score per cell is thresholded and
        the four box values are read as left/top/right/bottom distances
        (in stride units) from ``cell index + cell_offset``.

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound on the best class score.
            num_labels: Number of leading class channels.
            cell_offset: Value added to the cell index before applying the
                distances.
            reg_max: If given, the box channels are reshaped to
                ``(4, reg_max)`` per proposal and reduced to four values
                as the softmax-weighted sum of ``0 .. reg_max - 1``.
        """
        dfl = None
        if reg_max is not None:
            dfl = np.arange(0, reg_max, dtype=np.float32)

        for level, feat in enumerate(feats):
            stride = 8 << level
            score_feat, box_feat = np.split(feat, [num_labels], -1)
            score_feat = sigmoid(score_feat)
            best_label = score_feat.argmax(-1)
            best_score = score_feat.max(-1)

            rows, cols = np.where(best_score > conf_thres)
            num_proposal = rows.size
            if not num_proposal:
                continue

            scores = best_score[rows, cols]
            labels = best_label[rows, cols]
            boxes = box_feat[rows, cols]
            if dfl is not None:
                boxes = boxes.reshape(num_proposal, 4, reg_max)
                boxes = softmax(boxes, -1) @ dfl

            for k in range(num_proposal):
                left, top, right, bottom = boxes[k]

                x0 = (cols[k] + cell_offset - left) * stride
                y0 = (rows[k] + cell_offset - top) * stride
                x1 = (cols[k] + cell_offset + right) * stride
                y1 = (rows[k] + cell_offset + bottom) * stride

                self.scores_pro.append(float(scores[k]))
                self.boxes_pro.append(
                    np.array([x0, y0, x1 - x0, y1 - y0], dtype=np.float32))
                self.labels_pro.append(int(labels[k]))

    def __yolov5_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        """Decode anchor-based YOLOv5-style heads.

        After a sigmoid, each feature map is reshaped to
        ``(H, W, num_anchors, -1)`` and split into 4 box values, 1
        objectness value and the remaining class scores. A proposal is
        kept when both its objectness and its best
        ``class score * objectness`` exceed ``conf_thres``.

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound for objectness and final score.
            num_labels: Unused here; the class count follows from the
                reshape.
            **kwargs: ``anchors`` - per-level sequences of ``(w, h)``
                anchor sizes; defaults to the values below.
        """
        anchors: Union[List, Tuple] = kwargs.get(
            'anchors',
            [[(10, 13), (16, 30), (33, 23)],
             [(30, 61), (62, 45), (59, 119)],
             [(116, 90), (156, 198), (373, 326)]])
        for level, feat in enumerate(feats):
            stride = 8 << level
            feat_h, feat_w, _ = feat.shape
            anchor = anchors[level]
            feat = sigmoid(feat)
            feat = feat.reshape((feat_h, feat_w, len(anchor), -1))
            box_feat, conf_feat, score_feat = np.split(feat, [4, 5], -1)

            rows, cols, anchor_ids, _ = np.where(conf_feat > conf_thres)
            if not rows.size:
                continue

            score_feat = (score_feat[rows, cols, anchor_ids] *
                          conf_feat[rows, cols, anchor_ids])
            boxes = box_feat[rows, cols, anchor_ids]
            labels = score_feat.argmax(-1)
            scores = score_feat.max(-1)

            # Second filter: the combined score must also pass.
            for idx in np.where(scores > conf_thres)[0]:
                a_w, a_h = anchor[anchor_ids[idx]]
                x, y, w, h = boxes[idx]
                x = (x * 2.0 - 0.5 + cols[idx]) * stride
                y = (y * 2.0 - 0.5 + rows[idx]) * stride
                w = (w * 2.0)**2 * a_w
                h = (h * 2.0)**2 * a_h

                # Convert centre-based box to top-left corner + size.
                x0 = x - w / 2
                y0 = y - h / 2

                self.scores_pro.append(float(scores[idx]))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(labels[idx]))

    def __yolox_decode(self,
                       feats: List[ndarray],
                       conf_thres: float,
                       num_labels: int = 80,
                       **kwargs):
        """Decode YOLOX-style heads.

        Each feature map is split along the last axis into ``num_labels``
        class logits, 4 box values and the remaining objectness logit. A
        proposal is kept when both its objectness and its best
        ``class score * objectness`` exceed ``conf_thres``.

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound for objectness and final score.
            num_labels: Number of leading class channels.
            **kwargs: Unused.
        """
        for level, feat in enumerate(feats):
            stride = 8 << level
            score_feat, box_feat, conf_feat = np.split(
                feat, [num_labels, num_labels + 4], -1)
            conf_feat = sigmoid(conf_feat)

            rows, cols, _ = np.where(conf_feat > conf_thres)
            if not rows.size:
                continue

            score_feat = (sigmoid(score_feat[rows, cols]) *
                          conf_feat[rows, cols])
            boxes = box_feat[rows, cols]
            labels = score_feat.argmax(-1)
            scores = score_feat.max(-1)

            # Second filter: the combined score must also pass.
            for idx in np.where(scores > conf_thres)[0]:
                x, y, w, h = boxes[idx]

                x = (x + cols[idx]) * stride
                y = (y + rows[idx]) * stride
                w = np.exp(w) * stride
                h = np.exp(h) * stride

                # Convert centre-based box to top-left corner + size.
                x0 = x - w / 2
                y0 = y - h / 2

                self.scores_pro.append(float(scores[idx]))
                self.boxes_pro.append(
                    np.array([x0, y0, w, h], dtype=np.float32))
                self.labels_pro.append(int(labels[idx]))

    def __ppyoloe_decode(self,
                         feats: List[ndarray],
                         conf_thres: float,
                         num_labels: int = 80,
                         **kwargs):
        """Decode PPYOLOE-style heads (distribution-based box regression).

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound on the best class score.
            num_labels: Number of leading class channels.
            **kwargs: ``reg_max`` - number of bins per box side
                (default 17).
        """
        reg_max: int = kwargs.get('reg_max', 17)
        self._decode_distance_head(
            feats, conf_thres, num_labels, cell_offset=0.5, reg_max=reg_max)

    def __yolov6_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        """Decode YOLOv6-style heads (direct side distances, 0.5 offset).

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound on the best class score.
            num_labels: Number of leading class channels.
            **kwargs: Unused.
        """
        self._decode_distance_head(
            feats, conf_thres, num_labels, cell_offset=0.5)

    def __yolov7_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        """Decode YOLOv7-style heads via the YOLOv5 decoder.

        Only the default anchors differ from YOLOv5.

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound for objectness and final score.
            num_labels: Forwarded to the YOLOv5 decoder.
            **kwargs: ``anchors`` - per-level ``(w, h)`` anchor sizes;
                other keys are not forwarded.
        """
        anchors: Union[List, Tuple] = kwargs.get(
            'anchors',
            [[(12, 16), (19, 36), (40, 28)],
             [(36, 75), (76, 55), (72, 146)],
             [(142, 110), (192, 243), (459, 401)]])
        self.__yolov5_decode(feats, conf_thres, num_labels, anchors=anchors)

    def __rtmdet_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        """Decode RTMDet-style heads (direct side distances, no offset).

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound on the best class score.
            num_labels: Number of leading class channels.
            **kwargs: Unused.
        """
        self._decode_distance_head(
            feats, conf_thres, num_labels, cell_offset=0)

    def __yolov8_decode(self,
                        feats: List[ndarray],
                        conf_thres: float,
                        num_labels: int = 80,
                        **kwargs):
        """Decode YOLOv8-style heads via the PPYOLOE decoder.

        Only the default ``reg_max`` differs from PPYOLOE.

        Args:
            feats: Channel-last feature maps, one per level.
            conf_thres: Strict lower bound on the best class score.
            num_labels: Number of leading class channels.
            **kwargs: ``reg_max`` - number of bins per box side
                (default 16); other keys are not forwarded.
        """
        reg_max: int = kwargs.get('reg_max', 16)
        self.__ppyoloe_decode(feats, conf_thres, num_labels, reg_max=reg_max)
|