mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement] optimize delta2bboxes (#152)
* optimize delta2bboxes * ncnn update
This commit is contained in:
parent
141d956636
commit
2c25eff32c
@ -92,48 +92,42 @@ def delta2bbox(ctx,
|
|||||||
bboxes (Tensor): Boxes with shape (N, num_classes * 4) or (N, 4),
|
bboxes (Tensor): Boxes with shape (N, num_classes * 4) or (N, 4),
|
||||||
where 4 represent tl_x, tl_y, br_x, br_y.
|
where 4 represent tl_x, tl_y, br_x, br_y.
|
||||||
"""
|
"""
|
||||||
means = deltas.new_tensor(means).view(1,
|
means = deltas.new_tensor(means).view(1, -1)
|
||||||
-1).repeat(1,
|
stds = deltas.new_tensor(stds).view(1, -1)
|
||||||
deltas.size(-1) // 4)
|
delta_shape = deltas.shape
|
||||||
stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
|
reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 4))
|
||||||
denorm_deltas = deltas * stds + means
|
denorm_deltas = reshaped_deltas * stds + means
|
||||||
dx = denorm_deltas[..., 0::4]
|
|
||||||
dy = denorm_deltas[..., 1::4]
|
|
||||||
dw = denorm_deltas[..., 2::4]
|
|
||||||
dh = denorm_deltas[..., 3::4]
|
|
||||||
|
|
||||||
x1, y1 = rois[..., 0], rois[..., 1]
|
dxy = denorm_deltas[..., :2]
|
||||||
x2, y2 = rois[..., 2], rois[..., 3]
|
dwh = denorm_deltas[..., 2:]
|
||||||
# Compute center of each roi
|
|
||||||
px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
|
|
||||||
py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
|
|
||||||
# Compute width/height of each roi
|
|
||||||
pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
|
|
||||||
ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
|
|
||||||
|
|
||||||
dx_width = pw * dx
|
xy1 = rois[..., None, :2]
|
||||||
dy_height = ph * dy
|
xy2 = rois[..., None, 2:]
|
||||||
|
|
||||||
|
pxy = (xy1 + xy2) * 0.5
|
||||||
|
pwh = xy2 - xy1
|
||||||
|
dxy_wh = pwh * dxy
|
||||||
|
|
||||||
max_ratio = np.abs(np.log(wh_ratio_clip))
|
max_ratio = np.abs(np.log(wh_ratio_clip))
|
||||||
if add_ctr_clamp:
|
if add_ctr_clamp:
|
||||||
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
|
dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
|
||||||
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
|
dwh = torch.clamp(dwh, max=max_ratio)
|
||||||
dw = torch.clamp(dw, max=max_ratio)
|
|
||||||
dh = torch.clamp(dh, max=max_ratio)
|
|
||||||
else:
|
else:
|
||||||
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
|
||||||
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
|
||||||
# Use exp(network energy) to enlarge/shrink each roi
|
# Use exp(network energy) to enlarge/shrink each roi
|
||||||
gw = pw * dw.exp()
|
half_gwh = pwh * dwh.exp() * 0.5
|
||||||
gh = ph * dh.exp()
|
|
||||||
# Use network energy to shift the center of each roi
|
# Use network energy to shift the center of each roi
|
||||||
gx = px + dx_width
|
gxy = pxy + dxy_wh
|
||||||
gy = py + dy_height
|
|
||||||
# Convert center-xy/width/height to top-left, bottom-right
|
# Convert center-xy/width/height to top-left, bottom-right
|
||||||
x1 = gx - gw * 0.5
|
xy1 = gxy - half_gwh
|
||||||
y1 = gy - gh * 0.5
|
xy2 = gxy + half_gwh
|
||||||
x2 = gx + gw * 0.5
|
|
||||||
y2 = gy + gh * 0.5
|
x1 = xy1[..., 0]
|
||||||
|
y1 = xy1[..., 1]
|
||||||
|
x2 = xy2[..., 0]
|
||||||
|
y2 = xy2[..., 1]
|
||||||
|
|
||||||
if clip_border and max_shape is not None:
|
if clip_border and max_shape is not None:
|
||||||
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
||||||
@ -190,68 +184,42 @@ def delta2bbox__ncnn(ctx,
|
|||||||
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
|
or (N, num_classes * 4) or (N, 4), where 4 represent tl_x, tl_y,
|
||||||
br_x, br_y.
|
br_x, br_y.
|
||||||
"""
|
"""
|
||||||
means = deltas.new_tensor(means).view(1, 1,
|
means = deltas.new_tensor(means).view(1, 1, 1, -1).data
|
||||||
-1).repeat(1, deltas.size(-2),
|
stds = deltas.new_tensor(stds).view(1, 1, 1, -1).data
|
||||||
deltas.size(-1) // 4).data
|
delta_shape = deltas.shape
|
||||||
stds = deltas.new_tensor(stds).view(1, 1,
|
reshaped_deltas = deltas.view(delta_shape[:-1] + (-1, 4))
|
||||||
-1).repeat(1, deltas.size(-2),
|
denorm_deltas = reshaped_deltas * stds + means
|
||||||
deltas.size(-1) // 4).data
|
|
||||||
denorm_deltas = deltas * stds + means
|
|
||||||
if denorm_deltas.shape[-1] == 4:
|
|
||||||
dx = denorm_deltas[..., 0:1]
|
|
||||||
dy = denorm_deltas[..., 1:2]
|
|
||||||
dw = denorm_deltas[..., 2:3]
|
|
||||||
dh = denorm_deltas[..., 3:4]
|
|
||||||
else:
|
|
||||||
dx = denorm_deltas[..., 0::4]
|
|
||||||
dy = denorm_deltas[..., 1::4]
|
|
||||||
dw = denorm_deltas[..., 2::4]
|
|
||||||
dh = denorm_deltas[..., 3::4]
|
|
||||||
|
|
||||||
x1, y1 = rois[..., 0:1], rois[..., 1:2]
|
dxy = denorm_deltas[..., :2]
|
||||||
x2, y2 = rois[..., 2:3], rois[..., 3:4]
|
dwh = denorm_deltas[..., 2:]
|
||||||
|
|
||||||
# Compute center of each roi
|
xy1 = rois[..., None, :2]
|
||||||
px = (x1 + x2) * 0.5
|
xy2 = rois[..., None, 2:]
|
||||||
py = (y1 + y2) * 0.5
|
|
||||||
# Compute width/height of each roi
|
|
||||||
pw = x2 - x1
|
|
||||||
ph = y2 - y1
|
|
||||||
|
|
||||||
# do not use expand unless necessary
|
pxy = (xy1 + xy2) * 0.5
|
||||||
# since expand is a custom ops
|
pwh = xy2 - xy1
|
||||||
if px.shape[-1] != 4:
|
dxy_wh = pwh * dxy
|
||||||
px = px.expand_as(dx)
|
|
||||||
if py.shape[-1] != 4:
|
|
||||||
py = py.expand_as(dy)
|
|
||||||
if pw.shape[-1] != 4:
|
|
||||||
pw = pw.expand_as(dw)
|
|
||||||
if px.shape[-1] != 4:
|
|
||||||
ph = ph.expand_as(dh)
|
|
||||||
|
|
||||||
dx_width = pw * dx
|
|
||||||
dy_height = ph * dy
|
|
||||||
|
|
||||||
max_ratio = np.abs(np.log(wh_ratio_clip))
|
max_ratio = np.abs(np.log(wh_ratio_clip))
|
||||||
if add_ctr_clamp:
|
if add_ctr_clamp:
|
||||||
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
|
dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
|
||||||
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
|
dwh = torch.clamp(dwh, max=max_ratio)
|
||||||
dw = torch.clamp(dw, max=max_ratio)
|
|
||||||
dh = torch.clamp(dh, max=max_ratio)
|
|
||||||
else:
|
else:
|
||||||
dw = dw.clamp(min=-max_ratio, max=max_ratio)
|
dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
|
||||||
dh = dh.clamp(min=-max_ratio, max=max_ratio)
|
|
||||||
# Use exp(network energy) to enlarge/shrink each roi
|
# Use exp(network energy) to enlarge/shrink each roi
|
||||||
gw = pw * dw.exp()
|
half_gwh = pwh * dwh.exp() * 0.5
|
||||||
gh = ph * dh.exp()
|
|
||||||
# Use network energy to shift the center of each roi
|
# Use network energy to shift the center of each roi
|
||||||
gx = px + dx_width
|
gxy = pxy + dxy_wh
|
||||||
gy = py + dy_height
|
|
||||||
# Convert center-xy/width/height to top-left, bottom-right
|
# Convert center-xy/width/height to top-left, bottom-right
|
||||||
x1 = gx - gw * 0.5
|
xy1 = gxy - half_gwh
|
||||||
y1 = gy - gh * 0.5
|
xy2 = gxy + half_gwh
|
||||||
x2 = gx + gw * 0.5
|
|
||||||
y2 = gy + gh * 0.5
|
x1 = xy1[..., 0]
|
||||||
|
y1 = xy1[..., 1]
|
||||||
|
x2 = xy2[..., 0]
|
||||||
|
y2 = xy2[..., 1]
|
||||||
|
|
||||||
if clip_border and max_shape is not None:
|
if clip_border and max_shape is not None:
|
||||||
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
from mmdeploy.codebase.mmdet.deploy import clip_bboxes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user