[Fix] support ncnn faster-rcnn (#304)
* wtf * Support fcos ncnn end2end * support ncnn two stage detector * fix testpull/12/head
parent
fabdb473bb
commit
abdf64a576
|
@ -141,3 +141,121 @@ def delta2bbox(ctx,
|
|||
|
||||
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
|
||||
return bboxes
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.core.bbox.coder.delta_xywh_bbox_coder.delta2bbox', # noqa
|
||||
backend='ncnn')
|
||||
def delta2bbox__ncnn(ctx,
|
||||
rois,
|
||||
deltas,
|
||||
means=(0., 0., 0., 0.),
|
||||
stds=(1., 1., 1., 1.),
|
||||
max_shape=None,
|
||||
wh_ratio_clip=16 / 1000,
|
||||
clip_border=True,
|
||||
add_ctr_clamp=False,
|
||||
ctr_clamp=32):
|
||||
"""Rewrite `delta2bbox` for ncnn backend.
|
||||
Batch dimension is not supported by ncnn, but supported by pytorch.
|
||||
NCNN regards the lowest two dimensions as continuous address with byte
|
||||
alignment, so the lowest two dimensions are not absolutely independent.
|
||||
Reshape operator with -1 arguments should operates ncnn::Mat with
|
||||
dimension >= 3.
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
|
||||
deltas (Tensor): Encoded offsets with respect to each roi.
|
||||
Has shape (B, N, num_classes * 4) or (B, N, 4) or
|
||||
(N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
|
||||
when rois is a grid of anchors.Offset encoding follows [1]_.
|
||||
means (Sequence[float]): Denormalizing means for delta coordinates
|
||||
stds (Sequence[float]): Denormalizing standard deviation for delta
|
||||
coordinates
|
||||
max_shape (Sequence[int] or torch.Tensor or Sequence[
|
||||
Sequence[int]],optional): Maximum bounds for boxes, specifies
|
||||
(H, W, C) or (H, W). If rois shape is (B, N, 4), then
|
||||
the max_shape should be a Sequence[Sequence[int]]
|
||||
and the length of max_shape should also be B.
|
||||
wh_ratio_clip (float): Maximum aspect ratio for boxes.
|
||||
clip_border (bool, optional): Whether clip the objects outside the
|
||||
border of the image. Defaults to True.
|
||||
add_ctr_clamp (bool): Whether to add center clamp, when added, the
|
||||
predicted box is clamped is its center is too far away from
|
||||
the original anchor's center. Only used by YOLOF. Default False.
|
||||
ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
|
||||
Default 32.
|
||||
Return:
|
||||
bboxes (Tensor): Boxes with shape (B, N, num_classes * 4) or (B, N, 4)
|
||||
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
|
||||
br_x, br_y.
|
||||
"""
|
||||
means = deltas.new_tensor(means).view(1, 1,
|
||||
-1).repeat(1, deltas.size(-2),
|
||||
deltas.size(-1) // 4).data
|
||||
stds = deltas.new_tensor(stds).view(1, 1,
|
||||
-1).repeat(1, deltas.size(-2),
|
||||
deltas.size(-1) // 4).data
|
||||
denorm_deltas = deltas * stds + means
|
||||
if denorm_deltas.shape[-1] == 4:
|
||||
dx = denorm_deltas[..., 0:1]
|
||||
dy = denorm_deltas[..., 1:2]
|
||||
dw = denorm_deltas[..., 2:3]
|
||||
dh = denorm_deltas[..., 3:4]
|
||||
else:
|
||||
dx = denorm_deltas[..., 0::4]
|
||||
dy = denorm_deltas[..., 1::4]
|
||||
dw = denorm_deltas[..., 2::4]
|
||||
dh = denorm_deltas[..., 3::4]
|
||||
|
||||
x1, y1 = rois[..., 0:1], rois[..., 1:2]
|
||||
x2, y2 = rois[..., 2:3], rois[..., 3:4]
|
||||
|
||||
# Compute center of each roi
|
||||
px = (x1 + x2) * 0.5
|
||||
py = (y1 + y2) * 0.5
|
||||
# Compute width/height of each roi
|
||||
pw = x2 - x1
|
||||
ph = y2 - y1
|
||||
|
||||
# do not use expand unless necessary
|
||||
# since expand is a custom ops
|
||||
if px.shape[-1] != 4:
|
||||
px = px.expand_as(dx)
|
||||
if py.shape[-1] != 4:
|
||||
py = py.expand_as(dy)
|
||||
if pw.shape[-1] != 4:
|
||||
pw = pw.expand_as(dw)
|
||||
if px.shape[-1] != 4:
|
||||
ph = ph.expand_as(dh)
|
||||
|
||||
dx_width = pw * dx
|
||||
dy_height = ph * dy
|
||||
|
||||
max_ratio = np.abs(np.log(wh_ratio_clip))
|
||||
if add_ctr_clamp:
|
||||
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
|
||||
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
|
||||
dw = torch.clamp(dw, max=max_ratio)
|
||||
dh = torch.clamp(dh, max=max_ratio)
|
||||
else:
|
||||
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
||||
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
||||
# Use exp(network energy) to enlarge/shrink each roi
|
||||
gw = pw * dw.exp()
|
||||
gh = ph * dh.exp()
|
||||
# Use network energy to shift the center of each roi
|
||||
gx = px + dx_width
|
||||
gy = py + dy_height
|
||||
# Convert center-xy/width/height to top-left, bottom-right
|
||||
x1 = gx - gw * 0.5
|
||||
y1 = gy - gh * 0.5
|
||||
x2 = gx + gw * 0.5
|
||||
y2 = gy + gh * 0.5
|
||||
|
||||
if clip_border and max_shape is not None:
|
||||
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
||||
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)
|
||||
|
||||
bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
|
||||
return bboxes
|
||||
|
|
|
@ -483,8 +483,8 @@ class PartitionTwoStageModel(End2EndModel):
|
|||
rois,
|
||||
cls_score,
|
||||
bbox_pred,
|
||||
img_metas[0]['img_shape'],
|
||||
img_metas[0]['scale_factor'],
|
||||
img_metas[0][0]['img_shape'],
|
||||
img_metas[0][0]['scale_factor'],
|
||||
cfg=rcnn_test_cfg)
|
||||
|
||||
def forward_test(self, imgs: torch.Tensor, img_metas: Sequence[dict],
|
||||
|
|
|
@ -338,11 +338,11 @@ class TestPartitionTwoStageModel:
|
|||
rois = torch.rand(1, 10, 5)
|
||||
cls_score = torch.rand(10, 81)
|
||||
bbox_pred = torch.rand(10, 320)
|
||||
img_metas = [{
|
||||
img_metas = [[{
|
||||
'ori_shape': [32, 32, 3],
|
||||
'img_shape': [32, 32, 3],
|
||||
'scale_factor': [1, 1, 1, 1],
|
||||
}]
|
||||
}]]
|
||||
results = self.model.partition1_postprocess(
|
||||
rois=rois,
|
||||
cls_score=cls_score,
|
||||
|
|
Loading…
Reference in New Issue