parent
3be1779e66
commit
ab5c51f3ab
|
@ -104,23 +104,23 @@ def base_dense_head__get_bbox(ctx,
|
|||
score_factors = score_factors.permute(0, 2, 3,
|
||||
1).reshape(batch_size,
|
||||
-1).sigmoid()
|
||||
score_factors = score_factors.unsqueeze(2)
|
||||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
|
||||
if not is_dynamic_flag:
|
||||
priors = priors.data
|
||||
priors = priors.expand(batch_size, -1, priors.size(-1))
|
||||
if pre_topk > 0:
|
||||
if with_score_factors:
|
||||
nms_pre_score = (nms_pre_score * score_factors[..., None])
|
||||
nms_pre_score = nms_pre_score * score_factors
|
||||
if backend == Backend.TENSORRT:
|
||||
priors = pad_with_value(priors, 1, pre_topk)
|
||||
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
|
||||
scores = pad_with_value(scores, 1, pre_topk, 0.)
|
||||
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
|
||||
if with_score_factors:
|
||||
score_factors = pad_with_value(
|
||||
score_factors.unsqueeze(2), 1, pre_topk, 0.)
|
||||
else:
|
||||
score_factors = score_factors.unsqueeze(2)
|
||||
score_factors = pad_with_value(score_factors, 1, pre_topk,
|
||||
0.)
|
||||
|
||||
# Get maximum scores for foreground classes.
|
||||
if self.use_sigmoid_cls:
|
||||
max_scores, _ = nms_pre_score.max(-1)
|
||||
|
|
|
@ -11,8 +11,8 @@ import torch
|
|||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import (WrapModel, backend_checker, check_backend,
|
||||
get_model_outputs, get_rewrite_outputs)
|
||||
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
|
||||
get_rewrite_outputs)
|
||||
|
||||
import_codebase(Codebase.MMDET)
|
||||
|
||||
|
@ -292,15 +292,16 @@ def _replace_r50_with_r18(model):
|
|||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
@pytest.mark.parametrize('model_cfg_path', [
|
||||
'tests/test_codebase/test_mmdet/data/single_stage_model.json',
|
||||
'tests/test_codebase/test_mmdet/data/mask_model.json'
|
||||
])
|
||||
@backend_checker(Backend.ONNXRUNTIME)
|
||||
def test_forward_of_base_detector(model_cfg_path):
|
||||
def test_forward_of_base_detector(model_cfg_path, backend):
|
||||
check_backend(backend)
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type='onnxruntime'),
|
||||
backend_config=dict(type=backend.value),
|
||||
onnx_config=dict(
|
||||
output_names=['dets', 'labels'], input_shape=None),
|
||||
codebase_config=dict(
|
||||
|
|
Loading…
Reference in New Issue