[Fix] fix centernet (#1768)

* fix centernet

* update sdk transform
q.yao 2023-02-24 10:50:03 +08:00 committed by GitHub
parent b5fef4873a
commit bf36950f0e
6 changed files with 56 additions and 1 deletion

@@ -36,7 +36,9 @@ class DefaultFormatBundle : public Transform {
       }
     }
     if (!data.contains("scale_factor")) {
-      data["scale_factor"].push_back(1.0);
+      for (int i = 0; i < 4; ++i) {
+        data["scale_factor"].push_back(1.0);
+      }
     }
     if (!data.contains("img_norm_cfg")) {
       int channel = tensor.shape()[3];
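
A note on the DefaultFormatBundle change above: mmdet keeps scale_factor as a 4-vector (w_scale, h_scale, w_scale, h_scale) that detection postprocessing divides the (x1, y1, x2, y2) boxes by when mapping them back to the original image, so the SDK fallback default now has four elements instead of one. A rough Python sketch of that convention, illustrative only and not SDK code:

# Illustrative sketch of the scale_factor convention the SDK default now
# mirrors; plain PyTorch, not mmdeploy SDK code.
import torch

# (w_scale, h_scale, w_scale, h_scale); all ones when no resize was applied
scale_factor = torch.tensor([1.0, 1.0, 1.0, 1.0])
bboxes = torch.tensor([[10.0, 20.0, 110.0, 220.0]])  # (x1, y1, x2, y2)

# postprocessing divides element-wise to recover original-image coordinates
restored = bboxes / scale_factor
print(restored)  # unchanged here because the scale factors are 1.0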

@@ -41,6 +41,10 @@ class Pad : public Transform {
     } else {
       pad_val_ = 0.0f;
     }
+    logical_or_val_ = args.value("logical_or_val", 0);
+    add_pix_val_ = args.value("add_pix_val", 0);
     pad_to_square_ = args.value("pad_to_square", false);
     padding_mode_ = args.value("padding_mode", std::string("constant"));
     orientation_agnostic_ = args.value("orientation_agnostic", false);
@@ -89,6 +93,16 @@ class Pad : public Transform {
       data["pad_size_divisor"] = size_divisor_;
       data["pad_fixed_size"].push_back(pad_h);
       data["pad_fixed_size"].push_back(pad_w);
+    } else if (logical_or_val_ > 0) {
+      int pad_h = (height | logical_or_val_) + add_pix_val_;
+      int pad_w = (width | logical_or_val_) + add_pix_val_;
+      int offset_h = pad_h / 2 - height / 2;
+      int offset_w = pad_w / 2 - width / 2;
+      padding = {offset_w, offset_h, pad_w - width - offset_w, pad_h - height - offset_h};
+      data["border"].push_back(offset_h);
+      data["border"].push_back(offset_w);
+      data["border"].push_back(offset_h + height);
+      data["border"].push_back(offset_w + width);
     } else {
       output_tensor = tensor;
       data["pad_fixed_size"].push_back(height);
@@ -124,6 +138,8 @@ class Pad : public Transform {
   operation::Managed<operation::Pad> pad_;
   std::array<int, 2> size_;
   int size_divisor_;
+  int logical_or_val_;
+  int add_pix_val_;
   float pad_val_;
   bool pad_to_square_;
   bool orientation_agnostic_;
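
The new logical_or_val / add_pix_val branch reproduces, on the SDK side, the test-time padding that CenterNet's pipeline applies in mmdet: pad height and width up with a bitwise OR against a value such as 127, optionally add one pixel, center the image in the padded canvas, and record the resulting border. A small Python sketch of the same arithmetic, for illustration only (the function name is mine, not part of the SDK):

def logical_or_pad(height, width, logical_or_val=127, add_pix_val=1):
    """Mirror of the C++ branch above: compute padded size and border."""
    pad_h = (height | logical_or_val) + add_pix_val
    pad_w = (width | logical_or_val) + add_pix_val
    # center the original image inside the padded canvas
    offset_h = pad_h // 2 - height // 2
    offset_w = pad_w // 2 - width // 2
    # padding = (left, top, right, bottom); border = (top, left, bottom, right)
    padding = (offset_w, offset_h, pad_w - width - offset_w, pad_h - height - offset_h)
    border = (offset_h, offset_w, offset_h + height, offset_w + width)
    return padding, border

# e.g. a 500x375 (HxW) input is padded to 512x384 with the image roughly centered
print(logical_or_pad(500, 375))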

@@ -7,3 +7,4 @@ from . import necks  # noqa: F401,F403
 from . import roi_heads  # noqa: F401,F403
 from . import task_modules  # noqa: F401,F403
 from . import transformer  # noqa: F401,F403
+from . import utils  # noqa: F401,F403

@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from . import gaussian_target  # noqa: F401,F403

@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.utils.gaussian_target.get_topk_from_heatmap')
+def get_topk_from_heatmap__default(scores, k=20):
+    """Get top k positions from heatmap.
+
+    Replace view(batch, -1) with flatten
+    """
+    height, width = scores.size()[2:]
+    topk_scores, topk_inds = torch.topk(scores.flatten(1), k)
+    topk_clses = topk_inds // (height * width)
+    topk_inds = topk_inds % (height * width)
+    topk_ys = topk_inds // width
+    topk_xs = (topk_inds % width).int().float()
+    return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
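
The rewriter keeps the arithmetic of mmdet's get_topk_from_heatmap but swaps view(batch, -1) for flatten(1), which typically exports as a shape-independent Flatten rather than a Reshape driven by the runtime batch size. A hedged usage sketch: inside mmdeploy's RewriterContext the registered rewriter replaces the original mmdet function, so the two paths can be compared directly; the exact RewriterContext arguments below are an assumption and may differ across mmdeploy versions.

# Usage sketch (assumption-laden): RewriterContext's cfg/backend arguments
# may differ between mmdeploy versions.
import torch
from mmdet.models.utils import gaussian_target
from mmdeploy.core import RewriterContext

scores = torch.rand(1, 2, 4, 4)

# plain eager call: mmdet's original implementation
plain = gaussian_target.get_topk_from_heatmap(scores, k=20)

# inside the context, the registered rewriter patches the function,
# which is also how it takes effect during ONNX export
with RewriterContext(cfg=dict(), backend='onnxruntime'):
    rewritten = gaussian_target.get_topk_from_heatmap(scores, k=20)

for a, b in zip(plain, rewritten):
    torch.testing.assert_allclose(a, b)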

@@ -68,3 +68,17 @@ def test_get_mmdet_params():
         pre_top_k=-1,
         keep_top_k=100,
         background_label_id=-1)
+
+
+def test_get_topk_from_heatmap():
+    from mmdet.models.utils.gaussian_target import get_topk_from_heatmap
+
+    from mmdeploy.codebase.mmdet.models.utils.gaussian_target import \
+        get_topk_from_heatmap__default
+
+    scores = torch.rand(1, 2, 4, 4)
+    gts = get_topk_from_heatmap(scores, k=20)
+    outs = get_topk_from_heatmap__default(scores, k=20)
+
+    for gt, out in zip(gts, outs):
+        torch.testing.assert_allclose(gt, out)