fix mmdet tests (#302)

* fix mmdet tests

* fix
pull/1/head
RunningLeon 2021-12-18 14:36:54 +08:00 committed by GitHub
parent 3be1779e66
commit ab5c51f3ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 10 deletions

View File

@ -104,23 +104,23 @@ def base_dense_head__get_bbox(ctx,
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
-1).sigmoid()
score_factors = score_factors.unsqueeze(2)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
if not is_dynamic_flag:
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
if pre_topk > 0:
if with_score_factors:
nms_pre_score = (nms_pre_score * score_factors[..., None])
nms_pre_score = nms_pre_score * score_factors
if backend == Backend.TENSORRT:
priors = pad_with_value(priors, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
scores = pad_with_value(scores, 1, pre_topk, 0.)
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value(
score_factors.unsqueeze(2), 1, pre_topk, 0.)
else:
score_factors = score_factors.unsqueeze(2)
score_factors = pad_with_value(score_factors, 1, pre_topk,
0.)
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
max_scores, _ = nms_pre_score.max(-1)

View File

@ -11,8 +11,8 @@ import torch
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import (WrapModel, backend_checker, check_backend,
get_model_outputs, get_rewrite_outputs)
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
get_rewrite_outputs)
import_codebase(Codebase.MMDET)
@ -292,15 +292,16 @@ def _replace_r50_with_r18(model):
return model
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
@pytest.mark.parametrize('model_cfg_path', [
'tests/test_codebase/test_mmdet/data/single_stage_model.json',
'tests/test_codebase/test_mmdet/data/mask_model.json'
])
@backend_checker(Backend.ONNXRUNTIME)
def test_forward_of_base_detector(model_cfg_path):
def test_forward_of_base_detector(model_cfg_path, backend):
check_backend(backend)
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type='onnxruntime'),
backend_config=dict(type=backend.value),
onnx_config=dict(
output_names=['dets', 'labels'], input_shape=None),
codebase_config=dict(