mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Feature] Support end2end mmdet2.19 retina mobilessd (#286)
* support end2end mmdet2.19 retina mobilessd * fix yapf * add end2end fsaf * fix lint * fix comments * fix lint * add static configs * fix docformatter * move ssdhead * add rewrite for l2norm * fix ncnn ssd * fix isort * rename config * add ssd_head_ut * fix string * align ssd * remove unused bbox rewriter Co-authored-by: grimoire <yaoqian@sensetime.com> Co-authored-by: maningsheng <mnsheng@yeah.net>
This commit is contained in:
parent
31e8aed862
commit
3e8237d8bb
@ -0,0 +1,4 @@
|
||||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
|
||||
|
||||
codebase_config = dict(model_type='ncnn_end2end')
|
||||
onnx_config = dict(output_names=['detection_output'], input_shape=[300, 300])
|
@ -0,0 +1,4 @@
|
||||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']
|
||||
|
||||
codebase_config = dict(model_type='ncnn_end2end')
|
||||
onnx_config = dict(output_names=['detection_output'], input_shape=[1344, 800])
|
@ -3765,11 +3765,16 @@ int main(int argc, char** argv) {
|
||||
int nms_top_k = get_node_attr_i(node, "nms_top_k");
|
||||
int keep_top_k = get_node_attr_i(node, "keep_top_k");
|
||||
int num_class = get_node_attr_i(node, "num_class");
|
||||
std::vector<float> vars = get_node_attr_af(node, "vars");
|
||||
fprintf(pp, " 0=%d", num_class);
|
||||
fprintf(pp, " 1=%f", nms_threshold);
|
||||
fprintf(pp, " 2=%d", nms_top_k);
|
||||
fprintf(pp, " 3=%d", keep_top_k);
|
||||
fprintf(pp, " 4=%f", score_threshold);
|
||||
fprintf(pp, " 5=%f", vars[0]);
|
||||
fprintf(pp, " 6=%f", vars[1]);
|
||||
fprintf(pp, " 7=%f", vars[2]);
|
||||
fprintf(pp, " 8=%f", vars[3]);
|
||||
} else if (op == "Div") {
|
||||
int op_type = 3;
|
||||
fprintf(pp, " 0=%d", op_type);
|
||||
@ -4660,10 +4665,14 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
int image_width = get_node_attr_i(node, "image_width");
|
||||
int image_height = get_node_attr_i(node, "image_height");
|
||||
float step_width = get_node_attr_f(node, "step_width");
|
||||
float step_height = get_node_attr_f(node, "step_height");
|
||||
float offset = get_node_attr_f(node, "offset");
|
||||
int step_mmdetection = get_node_attr_i(node, "step_mmdetection");
|
||||
fprintf(pp, " 9=%d", image_width);
|
||||
fprintf(pp, " 10=%d", image_height);
|
||||
fprintf(pp, " 11=%f", step_width);
|
||||
fprintf(pp, " 12=%f", step_height);
|
||||
fprintf(pp, " 13=%f", offset);
|
||||
fprintf(pp, " 14=%d", step_mmdetection);
|
||||
} else if (op == "PixelShuffle") {
|
||||
|
@ -141,124 +141,3 @@ def delta2bbox(ctx,
|
||||
|
||||
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
|
||||
return bboxes
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox', # noqa
|
||||
backend='ncnn')
|
||||
def delta2bbox__ncnn(ctx,
|
||||
rois,
|
||||
deltas,
|
||||
means=(0., 0., 0., 0.),
|
||||
stds=(1., 1., 1., 1.),
|
||||
max_shape=None,
|
||||
wh_ratio_clip=16 / 1000,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
"""Rewrite `delta2bbox` for ncnn backend.
|
||||
|
||||
Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
NCNN regards the lowest two dimensions as continuous address with byte
|
||||
alignment, so the lowest two dimensions are not absolutely independent.
|
||||
Reshape operator with -1 arguments should operates ncnn::Mat with
|
||||
dimension >= 3.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
|
||||
deltas (Tensor): Encoded offsets with respect to each roi.
|
||||
Has shape (B, N, num_classes * 4) or (B, N, 4) or
|
||||
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
|
||||
when rois is a grid of anchors.Offset encoding follows [1]_.
|
||||
means (Sequence[float]): Denormalizing means for delta coordinates
|
||||
stds (Sequence[float]): Denormalizing standard deviation for delta
|
||||
coordinates
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
wh_ratio_clip (float): Maximum aspect ratio for boxes.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
add_ctr_clamp (bool): Whether to add center clamp, when added, the
|
||||
predicted box is clamped is its center is too far away from
|
||||
the original anchor's center. Only used by YOLOF. Default False.
|
||||
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
|
||||
Default 32.
|
||||
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
|
||||
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
|
||||
br_x, br_y.
|
||||
"""
|
||||
means = deltas.new_tensor(means).view(1, 1,
|
||||
-1).repeat(1, deltas.size(-2),
|
||||
deltas.size(-1) // 4).data
|
||||
stds = deltas.new_tensor(stds).view(1, 1,
|
||||
-1).repeat(1, deltas.size(-2),
|
||||
deltas.size(-1) // 4).data
|
||||
denorm_deltas = deltas * stds + means
|
||||
if denorm_deltas.shape[-1] == 4:
|
||||
dx = denorm_deltas[..., 0:1]
|
||||
dy = denorm_deltas[..., 1:2]
|
||||
dw = denorm_deltas[..., 2:3]
|
||||
dh = denorm_deltas[..., 3:4]
|
||||
else:
|
||||
dx = denorm_deltas[..., 0::4]
|
||||
dy = denorm_deltas[..., 1::4]
|
||||
dw = denorm_deltas[..., 2::4]
|
||||
dh = denorm_deltas[..., 3::4]
|
||||
|
||||
x1, y1 = rois[..., 0:1], rois[..., 1:2]
|
||||
x2, y2 = rois[..., 2:3], rois[..., 3:4]
|
||||
|
||||
# Compute center of each roi
|
||||
px = (x1 + x2) * 0.5
|
||||
py = (y1 + y2) * 0.5
|
||||
# Compute width/height of each roi
|
||||
pw = x2 - x1
|
||||
ph = y2 - y1
|
||||
|
||||
# do not use expand unless necessary
|
||||
# since expand is a custom ops
|
||||
if px.shape[-1] != 4:
|
||||
px = px.expand_as(dx)
|
||||
if py.shape[-1] != 4:
|
||||
py = py.expand_as(dy)
|
||||
if pw.shape[-1] != 4:
|
||||
pw = pw.expand_as(dw)
|
||||
if px.shape[-1] != 4:
|
||||
ph = ph.expand_as(dh)
|
||||
|
||||
dx_width = pw * dx
|
||||
dy_height = ph * dy
|
||||
|
||||
max_ratio = np.abs(np.log(wh_ratio_clip))
|
||||
if add_ctr_clamp:
|
||||
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
|
||||
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
|
||||
dw = torch.clamp(dw, max=max_ratio)
|
||||
dh = torch.clamp(dh, max=max_ratio)
|
||||
else:
|
||||
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
||||
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
||||
# Use exp(network energy) to enlarge/shrink each roi
|
||||
gw = pw * dw.exp()
|
||||
gh = ph * dh.exp()
|
||||
# Use network energy to shift the center of each roi
|
||||
gx = px + dx_width
|
||||
gy = py + dy_height
|
||||
# Convert center-xy/width/height to top-left, bottom-right
|
||||
x1 = gx - gw * 0.5
|
||||
y1 = gy - gh * 0.5
|
||||
x2 = gx + gw * 0.5
|
||||
y2 = gy + gh * 0.5
|
||||
|
||||
if clip_border and max_shape is not None:
|
||||
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
||||
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)
|
||||
|
||||
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
|
||||
return bboxes
|
||||
|
@ -71,75 +71,3 @@ def tblr2bboxes(ctx,
|
||||
bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
|
||||
|
||||
return bboxes
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.core.bbox.coder.tblr_bbox_coder.tblr2bboxes',
|
||||
backend='ncnn')
|
||||
def tblr2bboxes__ncnn(ctx,
|
||||
priors,
|
||||
tblr,
|
||||
normalizer=4.0,
|
||||
normalize_by_wh=True,
|
||||
max_shape=None,
|
||||
clip_border=True):
|
||||
"""Rewrite `tblr2bboxes` for ncnn backend.
|
||||
|
||||
Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
The negative value of axis in torch.cat is rewritten as corresponding
|
||||
positive value to avoid axis shift.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
|
||||
Shape: (N,4) or (B, N, 4).
|
||||
tblr (Tensor): Coords of network output in tblr form
|
||||
Shape: (N, 4) or (B, N, 4).
|
||||
normalizer (Sequence[float] | float): Normalization parameter of
|
||||
encoded boxes. By list, it represents the normalization factors at
|
||||
tblr dims. By float, it is the unified normalization factor at all
|
||||
dims. Default: 4.0
|
||||
normalize_by_wh (bool): Whether the tblr coordinates have been
|
||||
normalized by the side length (wh) of prior bboxes.
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If priors shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
|
||||
"""
|
||||
assert priors.size(0) == tblr.size(0)
|
||||
if priors.ndim == 3:
|
||||
assert priors.size(1) == tblr.size(1)
|
||||
|
||||
loc_decode = tblr * normalizer
|
||||
prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
|
||||
if normalize_by_wh:
|
||||
w = priors[..., 2:3] - priors[..., 0:1]
|
||||
h = priors[..., 3:4] - priors[..., 1:2]
|
||||
_h = h.unsqueeze(0).unsqueeze(-1)
|
||||
_loc_h = loc_decode[..., 0:2].unsqueeze(0).unsqueeze(-1)
|
||||
_w = w.unsqueeze(0).unsqueeze(-1)
|
||||
_loc_w = loc_decode[..., 2:4].unsqueeze(0).unsqueeze(-1)
|
||||
th = (_h * _loc_h).reshape(1, -1, 2)
|
||||
tw = (_w * _loc_w).reshape(1, -1, 2)
|
||||
loc_decode = torch.cat([th, tw], dim=2)
|
||||
top = loc_decode[..., 0:1]
|
||||
bottom = loc_decode[..., 1:2]
|
||||
left = loc_decode[..., 2:3]
|
||||
right = loc_decode[..., 3:4]
|
||||
xmin = prior_centers[..., 0].unsqueeze(-1) - left
|
||||
xmax = prior_centers[..., 0].unsqueeze(-1) + right
|
||||
ymin = prior_centers[..., 1].unsqueeze(-1) - top
|
||||
ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
|
||||
|
||||
if clip_border and max_shape is not None:
|
||||
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
||||
xmin, ymin, xmax, ymax = clip_bboxes(xmin, ymin, xmax, ymax, max_shape)
|
||||
bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1).view(priors.size())
|
||||
|
||||
return bboxes
|
||||
|
@ -37,7 +37,8 @@ class NcnnDetectionOutputOp(torch.autograd.Function):
|
||||
nms_threshold=0.45,
|
||||
nms_top_k=100,
|
||||
keep_top_k=100,
|
||||
num_class=81):
|
||||
num_class=81,
|
||||
target_stds=[0.1, 0.1, 0.2, 0.2]):
|
||||
"""Symbolic function of dummy onnx DetectionOutput op for ncnn."""
|
||||
return g.op(
|
||||
'mmdeploy::DetectionOutput',
|
||||
@ -49,6 +50,7 @@ class NcnnDetectionOutputOp(torch.autograd.Function):
|
||||
nms_top_k_i=nms_top_k,
|
||||
keep_top_k_i=keep_top_k,
|
||||
num_class_i=num_class,
|
||||
vars_f=target_stds,
|
||||
outputs=1)
|
||||
|
||||
@staticmethod
|
||||
@ -60,7 +62,8 @@ class NcnnDetectionOutputOp(torch.autograd.Function):
|
||||
nms_threshold=0.45,
|
||||
nms_top_k=100,
|
||||
keep_top_k=100,
|
||||
num_class=81):
|
||||
num_class=81,
|
||||
target_stds=[0.1, 0.1, 0.2, 0.2]):
|
||||
"""Forward function of dummy onnx DetectionOutput op for ncnn."""
|
||||
return torch.rand(1, 100, 6)
|
||||
|
||||
|
@ -40,6 +40,8 @@ class NcnnPriorBoxOp(torch.autograd.Function):
|
||||
aspect_ratios=[2, 3],
|
||||
image_height=300,
|
||||
image_width=300,
|
||||
step_height=300,
|
||||
step_width=300,
|
||||
max_sizes=[300],
|
||||
min_sizes=[285],
|
||||
offset=0.5,
|
||||
@ -51,6 +53,8 @@ class NcnnPriorBoxOp(torch.autograd.Function):
|
||||
aspect_ratios_f=aspect_ratios,
|
||||
image_height_i=image_height,
|
||||
image_width_i=image_width,
|
||||
step_height_f=step_height,
|
||||
step_width_f=step_width,
|
||||
max_sizes_f=max_sizes,
|
||||
min_sizes_f=min_sizes,
|
||||
offset_f=offset,
|
||||
@ -63,6 +67,8 @@ class NcnnPriorBoxOp(torch.autograd.Function):
|
||||
aspect_ratios=[2, 3],
|
||||
image_height=300,
|
||||
image_width=300,
|
||||
step_height=300,
|
||||
step_width=300,
|
||||
max_sizes=[300],
|
||||
min_sizes=[285],
|
||||
offset=0.5,
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dense_heads import * # noqa: F401,F403
|
||||
from .detectors import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .roi_heads import * # noqa: F401,F403
|
||||
|
@ -1,16 +1,17 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .anchor_head import anchor_head__get_bboxes__ncnn
|
||||
from .base_dense_head import base_dense_head__get_bbox
|
||||
from .base_dense_head import (base_dense_head__get_bbox,
|
||||
base_dense_head__get_bboxes__ncnn)
|
||||
from .fcos_head import fcos_head__get_bboxes__ncnn
|
||||
from .fovea_head import fovea_head__get_bboxes
|
||||
from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn
|
||||
from .ssd_head import ssd_head__get_bboxes__ncnn
|
||||
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
|
||||
from .yolox_head import yolox_head__get_bboxes
|
||||
|
||||
__all__ = [
|
||||
'anchor_head__get_bboxes__ncnn', 'fcos_head__get_bboxes__ncnn',
|
||||
'rpn_head__get_bboxes', 'rpn_head__get_bboxes__ncnn',
|
||||
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
|
||||
'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
|
||||
'fovea_head__get_bboxes'
|
||||
'fcos_head__get_bboxes__ncnn', 'rpn_head__get_bboxes',
|
||||
'rpn_head__get_bboxes__ncnn', 'yolov3_head__get_bboxes',
|
||||
'yolov3_head__get_bboxes__ncnn', 'yolox_head__get_bboxes',
|
||||
'base_dense_head__get_bbox', 'fovea_head__get_bboxes',
|
||||
'base_dense_head__get_bboxes__ncnn', 'ssd_head__get_bboxes__ncnn'
|
||||
]
|
||||
|
@ -1,7 +1,10 @@
|
||||
import torch
|
||||
from mmdet.core.bbox.coder.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
|
||||
from mmdet.core.bbox.coder.tblr_bbox_coder import TBLRBBoxCoder
|
||||
|
||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
||||
multiclass_nms, pad_with_value)
|
||||
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
|
||||
|
||||
@ -188,3 +191,205 @@ def base_dense_head__get_bbox(ctx,
|
||||
score_threshold=score_threshold,
|
||||
pre_top_k=pre_top_k,
|
||||
keep_top_k=keep_top_k)
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
|
||||
'.get_bboxes',
|
||||
backend='ncnn')
|
||||
def base_dense_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
score_factors=None,
|
||||
img_metas=None,
|
||||
cfg=None,
|
||||
rescale=False,
|
||||
with_nms=True,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of AnchorHead for NCNN backend.
|
||||
|
||||
Shape node and batch inference is not supported by ncnn. This function
|
||||
transform dynamic shape to constant shape and remove batch inference.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
cls_scores (list[Tensor]): Classification scores for all
|
||||
scale levels, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for all
|
||||
scale levels, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * 4, H, W).
|
||||
score_factors (list[Tensor], Optional): Score factor for
|
||||
all scale level, each is a 4D-tensor, has shape
|
||||
(batch_size, num_priors * 1, H, W). Default None.
|
||||
img_metas (list[dict], Optional): Image meta info. Default None.
|
||||
cfg (mmcv.Config, Optional): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used. Default None.
|
||||
rescale (bool): If True, return boxes in original image space.
|
||||
Default False.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default True.
|
||||
|
||||
|
||||
Returns:
|
||||
output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
assert not is_dynamic_shape(deploy_cfg), 'base_dense_head for ncnn\
|
||||
only supports static shape.'
|
||||
|
||||
if score_factors is None:
|
||||
# e.g. Retina, FreeAnchor, Foveabox, etc.
|
||||
with_score_factors = False
|
||||
else:
|
||||
# e.g. FCOS, PAA, ATSS, AutoAssign, etc.
|
||||
with_score_factors = True
|
||||
assert len(cls_scores) == len(score_factors)
|
||||
batch_size = cls_scores[0].shape[0]
|
||||
assert batch_size == 1, f'ncnn deployment requires batch size 1, \
|
||||
got {batch_size}.'
|
||||
|
||||
num_levels = len(cls_scores)
|
||||
if with_score_factors:
|
||||
score_factor_list = score_factors
|
||||
else:
|
||||
score_factor_list = [None for _ in range(num_levels)]
|
||||
|
||||
if isinstance(self.bbox_coder, DeltaXYWHBBoxCoder):
|
||||
vars = torch.tensor(self.bbox_coder.stds)
|
||||
elif isinstance(self.bbox_coder, TBLRBBoxCoder):
|
||||
normalizer = self.bbox_coder.normalizer
|
||||
if isinstance(normalizer, float):
|
||||
vars = torch.tensor([normalizer, normalizer, 1, 1],
|
||||
dtype=torch.float32)
|
||||
else:
|
||||
assert len(normalizer) == 4, f'normalizer of tblr must be 4,\
|
||||
got {len(normalizer)}'
|
||||
|
||||
assert (normalizer[0] == normalizer[1] and normalizer[2]
|
||||
== normalizer[3]), 'normalizer between top \
|
||||
and bottom, left and right must be the same value, or \
|
||||
we can not transform it to delta_xywh format.'
|
||||
|
||||
vars = torch.tensor([normalizer[0], normalizer[2], 1, 1],
|
||||
dtype=torch.float32)
|
||||
else:
|
||||
vars = None
|
||||
if isinstance(img_metas[0]['img_shape'][0], int):
|
||||
assert isinstance(img_metas[0]['img_shape'][1], int)
|
||||
img_height = img_metas[0]['img_shape'][0]
|
||||
img_width = img_metas[0]['img_shape'][1]
|
||||
else:
|
||||
img_height = img_metas[0]['img_shape'][0].item()
|
||||
img_width = img_metas[0]['img_shape'][1].item()
|
||||
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
|
||||
mlvl_priors = self.prior_generator.grid_priors(
|
||||
featmap_sizes, device=cls_scores[0].device)
|
||||
batch_mlvl_priors = []
|
||||
for i in range(num_levels):
|
||||
_priors = mlvl_priors[i].reshape(1, -1, mlvl_priors[i].shape[-1])
|
||||
x1 = _priors[:, :, 0:1] / img_width
|
||||
y1 = _priors[:, :, 1:2] / img_height
|
||||
x2 = _priors[:, :, 2:3] / img_width
|
||||
y2 = _priors[:, :, 3:4] / img_height
|
||||
priors = torch.cat([x1, y1, x2, y2], dim=2).data
|
||||
batch_mlvl_priors.append(priors)
|
||||
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
|
||||
batch_mlvl_bboxes = []
|
||||
batch_mlvl_scores = []
|
||||
batch_mlvl_score_factors = []
|
||||
|
||||
for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
|
||||
enumerate(zip(cls_scores, bbox_preds,
|
||||
score_factor_list, batch_mlvl_priors)):
|
||||
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
|
||||
# NCNN needs 3 dimensions to reshape when including -1 parameter in
|
||||
# width or height dimension.
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
if with_score_factors:
|
||||
score_factor = score_factor.permute(0, 2, 3, 1).\
|
||||
reshape(batch_size, -1, 1).sigmoid()
|
||||
cls_score = cls_score.permute(0, 2, 3, 1).\
|
||||
reshape(batch_size, -1, self.cls_out_channels)
|
||||
# NCNN DetectionOutput op needs num_class + 1 classes. So if sigmoid
|
||||
# score, we should padding background class according to mmdetection
|
||||
# num_class definition.
|
||||
if self.use_sigmoid_cls:
|
||||
scores = cls_score.sigmoid()
|
||||
dummy_background_score = torch.zeros(
|
||||
batch_size, cls_score.shape[1], 1, device=cls_score.device)
|
||||
scores = torch.cat([scores, dummy_background_score], dim=2)
|
||||
else:
|
||||
scores = cls_score.softmax(-1)
|
||||
batch_mlvl_bboxes.append(bbox_pred)
|
||||
batch_mlvl_scores.append(scores)
|
||||
batch_mlvl_score_factors.append(score_factor)
|
||||
|
||||
batch_mlvl_priors = torch.cat(batch_mlvl_priors, dim=1)
|
||||
batch_mlvl_scores = torch.cat(batch_mlvl_scores, dim=1)
|
||||
batch_mlvl_bboxes = torch.cat(batch_mlvl_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat([
|
||||
batch_mlvl_scores[:, :, self.num_classes:],
|
||||
batch_mlvl_scores[:, :, 0:self.num_classes]
|
||||
],
|
||||
dim=2)
|
||||
if isinstance(self.bbox_coder, TBLRBBoxCoder):
|
||||
batch_mlvl_bboxes = _tblr_pred_to_delta_xywh_pred(
|
||||
batch_mlvl_bboxes, vars[0:2])
|
||||
# flatten for ncnn DetectionOutput op inputs.
|
||||
batch_mlvl_vars = vars.expand_as(batch_mlvl_priors)
|
||||
batch_mlvl_bboxes = batch_mlvl_bboxes.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_priors = batch_mlvl_priors.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_vars = batch_mlvl_vars.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)\
|
||||
.data
|
||||
|
||||
post_params = get_post_processing_params(ctx.cfg)
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
|
||||
output__ncnn = ncnn_detection_output_forward(
|
||||
batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors,
|
||||
score_threshold, iou_threshold, pre_top_k, keep_top_k,
|
||||
self.num_classes + 1,
|
||||
vars.cpu().detach().numpy())
|
||||
|
||||
return output__ncnn
|
||||
|
||||
|
||||
def _tblr_pred_to_delta_xywh_pred(bbox_pred: torch.Tensor,
|
||||
normalizer: torch.Tensor) -> torch.Tensor:
|
||||
"""Transform tblr format bbox prediction to delta_xywh format for ncnn.
|
||||
|
||||
An internal function for transforming tblr format bbox prediction to
|
||||
delta_xywh format. NCNN DetectionOutput layer needs delta_xywh format
|
||||
bbox_pred as input.
|
||||
|
||||
Args:
|
||||
bbox_pred (Tensor): The bbox prediction of tblr format, has shape
|
||||
(N, num_det, 4).
|
||||
normalizer (Tensor): The normalizer scale of bbox horizon and
|
||||
vertical coordinates, has shape (2,).
|
||||
|
||||
Returns:
|
||||
Tensor: The delta_xywh format bbox predictions.
|
||||
"""
|
||||
top = bbox_pred[:, :, 0:1]
|
||||
bottom = bbox_pred[:, :, 1:2]
|
||||
left = bbox_pred[:, :, 2:3]
|
||||
right = bbox_pred[:, :, 3:4]
|
||||
h = (top + bottom) * normalizer[0]
|
||||
w = (left + right) * normalizer[1]
|
||||
|
||||
_dwh = torch.cat([w, h], dim=2)
|
||||
assert torch.all(_dwh >= 0), 'wh must be positive before log.'
|
||||
dwh = torch.log(_dwh)
|
||||
|
||||
return torch.cat([(right - left) / 2, (bottom - top) / 2, dwh], dim=2)
|
||||
|
124
mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py
Normal file
124
mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py
Normal file
@ -0,0 +1,124 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase.mmdet import get_post_processing_params
|
||||
from mmdeploy.codebase.mmdet.core.ops import (ncnn_detection_output_forward,
|
||||
ncnn_prior_box_forward)
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.dense_heads.SSDHead.get_bboxes', backend='ncnn')
|
||||
def ssd_head__get_bboxes__ncnn(ctx,
|
||||
self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
img_metas,
|
||||
with_nms=True,
|
||||
cfg=None,
|
||||
**kwargs):
|
||||
"""Rewrite `get_bboxes` of SSDHead for NCNN backend.
|
||||
|
||||
This rewriter using ncnn PriorBox and DetectionOutput layer to
|
||||
support dynamic deployment, and has higher speed.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
cls_scores (list[Tensor]): Box scores for each level in the
|
||||
feature pyramid, has shape
|
||||
(N, num_anchors * num_classes, H, W).
|
||||
bbox_preds (list[Tensor]): Box energies / deltas for each
|
||||
level in the feature pyramid, has shape
|
||||
(N, num_anchors * 4, H, W).
|
||||
img_metas (list[dict]): Meta information of each image, e.g.,
|
||||
image size, scaling factor, etc.
|
||||
with_nms (bool): If True, do nms before return boxes.
|
||||
Default: True.
|
||||
cfg (mmcv.Config | None): Test / postprocessing configuration,
|
||||
if None, test_cfg would be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor: outputs, shape is [N, num_det, 6].
|
||||
"""
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
deploy_cfg = ctx.cfg
|
||||
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||
num_levels = len(cls_scores)
|
||||
aspect_ratio = [
|
||||
ratio[ratio > 1].detach().cpu().numpy()
|
||||
for ratio in self.anchor_generator.ratios
|
||||
]
|
||||
strides = self.anchor_generator.strides
|
||||
min_sizes = self.anchor_generator.base_sizes
|
||||
if is_dynamic_flag:
|
||||
max_sizes = min_sizes[1:] + img_metas[0]['img_shape'][0:1].tolist()
|
||||
img_height = img_metas[0]['img_shape'][0].item()
|
||||
img_width = img_metas[0]['img_shape'][1].item()
|
||||
else:
|
||||
max_sizes = min_sizes[1:] + img_metas[0]['img_shape'][0:1]
|
||||
img_height = img_metas[0]['img_shape'][0]
|
||||
img_width = img_metas[0]['img_shape'][1]
|
||||
|
||||
# if no reshape, concat will be error in ncnn.
|
||||
mlvl_anchors = [
|
||||
ncnn_prior_box_forward(cls_scores[i], aspect_ratio[i], img_height,
|
||||
img_width, strides[i][0], strides[i][1],
|
||||
max_sizes[i:i + 1],
|
||||
min_sizes[i:i + 1]).reshape(1, 2, -1)
|
||||
for i in range(num_levels)
|
||||
]
|
||||
|
||||
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
|
||||
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
|
||||
|
||||
cfg = self.test_cfg if cfg is None else cfg
|
||||
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(mlvl_anchors)
|
||||
batch_size = 1
|
||||
|
||||
mlvl_valid_bboxes = []
|
||||
mlvl_scores = []
|
||||
for level_id, cls_score, bbox_pred in zip(
|
||||
range(num_levels), mlvl_cls_scores, mlvl_bbox_preds):
|
||||
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
|
||||
cls_score = cls_score.permute(0, 2, 3,
|
||||
1).reshape(batch_size, -1,
|
||||
self.cls_out_channels)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
|
||||
mlvl_valid_bboxes.append(bbox_pred)
|
||||
mlvl_scores.append(cls_score)
|
||||
|
||||
# NCNN DetectionOutput layer uses background class at 0 position, but
|
||||
# in mmdetection, background class is at self.num_classes position.
|
||||
# We should adapt for ncnn.
|
||||
batch_mlvl_valid_bboxes = torch.cat(mlvl_valid_bboxes, dim=1)
|
||||
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
|
||||
if self.use_sigmoid_cls:
|
||||
batch_mlvl_scores = batch_mlvl_scores.sigmoid()
|
||||
else:
|
||||
batch_mlvl_scores = batch_mlvl_scores.softmax(-1)
|
||||
batch_mlvl_anchors = torch.cat(mlvl_anchors, dim=2)
|
||||
batch_mlvl_scores = torch.cat([
|
||||
batch_mlvl_scores[:, :, self.num_classes:],
|
||||
batch_mlvl_scores[:, :, 0:self.num_classes]
|
||||
],
|
||||
dim=2)
|
||||
batch_mlvl_valid_bboxes = batch_mlvl_valid_bboxes.reshape(
|
||||
batch_size, 1, -1)
|
||||
batch_mlvl_scores = batch_mlvl_scores.reshape(batch_size, 1, -1)
|
||||
batch_mlvl_anchors = batch_mlvl_anchors.reshape(batch_size, 2, -1)
|
||||
|
||||
post_params = get_post_processing_params(deploy_cfg)
|
||||
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
|
||||
score_threshold = cfg.get('score_thr', post_params.score_threshold)
|
||||
pre_top_k = post_params.pre_top_k
|
||||
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
|
||||
|
||||
output__ncnn = ncnn_detection_output_forward(
|
||||
batch_mlvl_valid_bboxes, batch_mlvl_scores, batch_mlvl_anchors,
|
||||
score_threshold, iou_threshold, pre_top_k, keep_top_k,
|
||||
self.num_classes + 1)
|
||||
|
||||
return output__ncnn
|
10
mmdeploy/codebase/mmdet/models/necks.py
Normal file
10
mmdeploy/codebase/mmdet/models/necks.py
Normal file
@ -0,0 +1,10 @@
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.models.necks.ssd_neck.L2Norm.forward')
|
||||
def l2norm__forward__default(ctx, self, x):
|
||||
return torch.nn.functional.normalize(
|
||||
x, dim=1) * self.weight[None, :, None, None]
|
@ -139,114 +139,6 @@ def get_single_roi_extractor():
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
@pytest.mark.parametrize('is_ssd', [True, False])
|
||||
def test_anchor_head_get_bboxes(backend_type: Backend, is_ssd: bool):
|
||||
"""Test get_bboxes rewrite of anchor head."""
|
||||
check_backend(backend_type)
|
||||
if is_ssd:
|
||||
anchor_head = get_ssd_head_model()
|
||||
else:
|
||||
anchor_head = get_anchor_head_model()
|
||||
anchor_head.cpu().eval()
|
||||
s = 128
|
||||
img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
'pad_shape': (s, s, 3),
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
if is_ssd:
|
||||
output_names = ['output']
|
||||
else:
|
||||
output_names = ['dets', 'labels']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=200,
|
||||
pre_top_k=5000,
|
||||
keep_top_k=100,
|
||||
background_label_id=-1,
|
||||
))))
|
||||
|
||||
if not is_ssd:
|
||||
# For the general anchor_head:
|
||||
# the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
|
||||
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
|
||||
# the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
|
||||
# (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
|
||||
seed_everything(1234)
|
||||
cls_score = [
|
||||
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
||||
]
|
||||
seed_everything(5678)
|
||||
bboxes = [
|
||||
torch.rand(1, 36, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
||||
]
|
||||
else:
|
||||
# For the ssd_head:
|
||||
# the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10),
|
||||
# (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1)
|
||||
# the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10),
|
||||
# (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1)
|
||||
feat_shape = [20, 10, 5, 3, 2, 1]
|
||||
num_prior = 6
|
||||
seed_everything(1234)
|
||||
cls_score = [
|
||||
torch.rand(1, 30, feat_shape[i], feat_shape[i])
|
||||
for i in range(num_prior)
|
||||
]
|
||||
seed_everything(5678)
|
||||
bboxes = [
|
||||
torch.rand(1, 24, feat_shape[i], feat_shape[i])
|
||||
for i in range(num_prior)
|
||||
]
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
model_outputs = get_model_outputs(anchor_head, 'get_bboxes', model_inputs)
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32)
|
||||
wrapped_model = WrapModel(
|
||||
anchor_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
if is_backend_output:
|
||||
if isinstance(rewrite_outputs, dict):
|
||||
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
|
||||
for model_output, rewrite_output in zip(model_outputs[0],
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze().cpu().numpy()
|
||||
# hard code to make two tensors with the same shape
|
||||
# rewrite and original codes applied different nms strategy
|
||||
assert np.allclose(
|
||||
model_output[:rewrite_output.shape[0]],
|
||||
rewrite_output,
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_get_bboxes_of_fcos_head(backend_type: Backend):
|
||||
check_backend(backend_type)
|
||||
@ -1126,8 +1018,8 @@ def test_get_bboxes_of_vfnet_head(backend_type: Backend):
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type',
|
||||
[Backend.ONNXRUNTIME, Backend.OPENVINO])
|
||||
def test_get_bboxes_of_base_dense_head(backend_type: Backend):
|
||||
[Backend.ONNXRUNTIME, Backend.NCNN, Backend.OPENVINO])
|
||||
def test_base_dense_head_get_bboxes(backend_type: Backend):
|
||||
"""Test get_bboxes rewrite of base dense head."""
|
||||
check_backend(backend_type)
|
||||
anchor_head = get_anchor_head_model()
|
||||
@ -1139,7 +1031,10 @@ def test_get_bboxes_of_base_dense_head(backend_type: Backend):
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
|
||||
output_names = ['dets', 'labels']
|
||||
if backend_type != Backend.NCNN:
|
||||
output_names = ['dets', 'labels']
|
||||
else:
|
||||
output_names = ['output']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
@ -1204,3 +1099,89 @@ def test_get_bboxes_of_base_dense_head(backend_type: Backend):
|
||||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
|
||||
def test_ssd_head_get_bboxes(backend_type: Backend):
|
||||
"""Test get_bboxes rewrite of anchor head."""
|
||||
check_backend(backend_type)
|
||||
ssd_head = get_ssd_head_model()
|
||||
ssd_head.cpu().eval()
|
||||
s = 128
|
||||
img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
'pad_shape': (s, s, 3),
|
||||
'img_shape': (s, s, 3)
|
||||
}]
|
||||
output_names = ['output']
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(output_names=output_names, input_shape=None),
|
||||
codebase_config=dict(
|
||||
type='mmdet',
|
||||
task='ObjectDetection',
|
||||
post_processing=dict(
|
||||
score_threshold=0.05,
|
||||
iou_threshold=0.5,
|
||||
max_output_boxes_per_class=200,
|
||||
pre_top_k=5000,
|
||||
keep_top_k=100,
|
||||
background_label_id=-1,
|
||||
))))
|
||||
|
||||
# For the ssd_head:
|
||||
# the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10),
|
||||
# (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1)
|
||||
# the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10),
|
||||
# (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1)
|
||||
feat_shape = [20, 10, 5, 3, 2, 1]
|
||||
num_prior = 6
|
||||
seed_everything(1234)
|
||||
cls_score = [
|
||||
torch.rand(1, 30, feat_shape[i], feat_shape[i])
|
||||
for i in range(num_prior)
|
||||
]
|
||||
seed_everything(5678)
|
||||
bboxes = [
|
||||
torch.rand(1, 24, feat_shape[i], feat_shape[i])
|
||||
for i in range(num_prior)
|
||||
]
|
||||
|
||||
# to get outputs of pytorch model
|
||||
model_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
model_outputs = get_model_outputs(ssd_head, 'get_bboxes', model_inputs)
|
||||
|
||||
# to get outputs of onnx model after rewrite
|
||||
img_metas[0]['img_shape'] = torch.tensor([s, s], dtype=torch.int32)
|
||||
wrapped_model = WrapModel(
|
||||
ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
if is_backend_output:
|
||||
if isinstance(rewrite_outputs, dict):
|
||||
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
|
||||
for model_output, rewrite_output in zip(model_outputs[0],
|
||||
rewrite_outputs):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze().cpu().numpy()
|
||||
# hard code to make two tensors with the same shape
|
||||
# rewrite and original codes applied different nms strategy
|
||||
assert np.allclose(
|
||||
model_output[:rewrite_output.shape[0]],
|
||||
rewrite_output,
|
||||
rtol=1e-03,
|
||||
atol=1e-05)
|
||||
else:
|
||||
assert rewrite_outputs is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user