q.yao 3f261e6d50
[Refactor] Refactor rewriter context for MMRazor (#1483)
* wip

* update rewriter

* Support all codebase

* update docs

* fix ssd

* rename qualname

* support torch.fx.wrap

* import by torch version

Co-authored-by: pppppM <gjf_mail@126.com>
2022-12-13 19:03:56 +08:00

34 lines
1.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.structures.bbox import get_box_tensor
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
    'mmrotate.models.task_modules.coders.gliding_vertex_coder'
    '.GVFixCoder.decode')
def gvfixcoder__decode(self, hboxes, fix_deltas):
    """Rewriter for GVFixCoder decode, support more dimension input.

    Each horizontal box is turned into a 4-vertex polygon with one vertex
    on every box edge: the top and bottom vertices slide along x, the right
    and left vertices slide along y, each by a fraction of the box width or
    height taken from ``fix_deltas``.

    Args:
        hboxes: Horizontal boxes with ``size(-1) == 4``; the last dimension
            is read as ``(x1, y1, x2, y2)``. Passed through
            ``get_box_tensor`` before indexing (presumably to unwrap a box
            container into a plain tensor -- confirm against mmdet).
        fix_deltas: Sliding ratios, indexed on the last dimension as
            ``(top, right, bottom, left)``. Assumed to share the leading
            dimensions of ``hboxes`` and to have last dimension 4;
            ``torch.stack`` raises if the decoded pieces differ in shape.

    Returns:
        torch.Tensor: Polygons of shape ``(..., 8)`` laid out as
        ``(top_x, y1, x2, right_y, bottom_x, y2, x1, left_y)``, where
        ``...`` are the leading dimensions of ``hboxes``.
    """
    assert hboxes.size(
        -1) == 4, f'expect hboxes.size(-1)==4 get {hboxes.size(-1)}.'
    hboxes = get_box_tensor(hboxes)

    # Strided slices keep a trailing axis of size 1, so every term below
    # has shape (..., 1).
    x1 = hboxes[..., 0::4]
    y1 = hboxes[..., 1::4]
    x2 = hboxes[..., 2::4]
    y2 = hboxes[..., 3::4]
    w = x2 - x1
    h = y2 - y1

    pred_t_x = x1 + w * fix_deltas[..., 0::4]
    pred_r_y = y1 + h * fix_deltas[..., 1::4]
    pred_d_x = x2 - w * fix_deltas[..., 2::4]
    pred_l_y = y2 - h * fix_deltas[..., 3::4]

    polys = torch.stack(
        [pred_t_x, y1, x2, pred_r_y, pred_d_x, y2, x1, pred_l_y], dim=-1)
    # ``polys`` is (..., 1, 8); merge only the last two axes. The start index
    # is derived from the rank instead of being hard-coded to 2, so inputs
    # with any number of leading dimensions yield (..., 8). For the usual
    # (B, N, 4) input this is exactly ``flatten(2)``.
    return polys.flatten(polys.dim() - 2)