[Fix] Fix VFNet test (#281)
* [Fix] fix bugs for mmcls performance test (#269) * fix bugs for mmcls performance test * fix yapf * add comments of CLASSES attribute * Fix test_get_bboxes_of_vfnet_head * Fix Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>pull/1/head
parent
0f90a0af0a
commit
a96e5f9d76
|
@ -1 +0,0 @@
|
|||
Subproject commit d2ed955a32d6e61e4a1f8acd4dbf59b6accd888a
|
|
@ -76,13 +76,9 @@ def cascade_roi_head__simple_test(ctx, self, x, proposals, img_metas,
|
|||
cls_score = sum(ms_scores) / float(len(ms_scores))
|
||||
bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
|
||||
rois = rois.reshape(batch_size, num_proposals_per_img, -1)
|
||||
scale_factor = img_metas[0].get('scale_factor', None)
|
||||
det_bboxes, det_labels = self.bbox_head[-1].get_bboxes(
|
||||
rois,
|
||||
cls_score,
|
||||
bbox_pred,
|
||||
max_shape,
|
||||
img_metas[0]['scale_factor'],
|
||||
cfg=rcnn_test_cfg)
|
||||
rois, cls_score, bbox_pred, max_shape, scale_factor, cfg=rcnn_test_cfg)
|
||||
|
||||
if not self.with_mask:
|
||||
return det_bboxes, det_labels
|
||||
|
|
|
@ -1053,44 +1053,13 @@ def get_vfnet_head_model():
|
|||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.OPENVINO])
|
||||
@pytest.mark.parametrize('backend_type',
|
||||
[Backend.OPENVINO, Backend.ONNXRUNTIME])
|
||||
def test_get_bboxes_of_vfnet_head(backend_type: Backend):
|
||||
"""Test get_bboxes rewrite of VFNet head."""
|
||||
check_backend(backend_type)
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
"""Stub for VFNetHead with fake bbox_preds operations.
|
||||
|
||||
Then bbox_preds will be one of the inputs to the ONNX graph.
|
||||
"""
|
||||
|
||||
def __init__(self, vfnet_head):
|
||||
super().__init__()
|
||||
self.vfnet_head = vfnet_head
|
||||
|
||||
def get_bboxes(self,
|
||||
cls_scores,
|
||||
bbox_preds,
|
||||
bbox_preds_refine,
|
||||
img_metas,
|
||||
cfg=None,
|
||||
rescale=None,
|
||||
with_nms=True):
|
||||
tmp_bbox_pred_refine = []
|
||||
for bbox_pred, bbox_pred_refine in zip(bbox_preds,
|
||||
bbox_preds_refine):
|
||||
tmp = bbox_pred_refine + bbox_pred
|
||||
tmp = tmp - bbox_pred
|
||||
tmp_bbox_pred_refine.append(tmp)
|
||||
bbox_preds_refine = tmp_bbox_pred_refine
|
||||
return self.vfnet_head.get_bboxes(cls_scores, bbox_preds,
|
||||
bbox_preds_refine, img_metas,
|
||||
cfg, rescale, with_nms)
|
||||
|
||||
test_model = TestModel(get_vfnet_head_model())
|
||||
test_model.requires_grad_(False)
|
||||
test_model.cpu().eval()
|
||||
|
||||
vfnet_head = get_vfnet_head_model()
|
||||
vfnet_head.cpu().eval()
|
||||
s = 16
|
||||
img_metas = [{
|
||||
'scale_factor': np.ones(4),
|
||||
|
@ -1116,32 +1085,24 @@ def test_get_bboxes_of_vfnet_head(backend_type: Backend):
|
|||
|
||||
seed_everything(1234)
|
||||
cls_score = [
|
||||
torch.rand(1, test_model.vfnet_head.num_classes, pow(2, i), pow(2, i))
|
||||
torch.rand(1, vfnet_head.num_classes, pow(2, i), pow(2, i))
|
||||
for i in range(5, 0, -1)
|
||||
]
|
||||
seed_everything(5678)
|
||||
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
|
||||
seed_everything(9101)
|
||||
bbox_preds_refine = [
|
||||
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
|
||||
]
|
||||
|
||||
model_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
'bbox_preds_refine': bbox_preds_refine,
|
||||
'img_metas': img_metas
|
||||
}
|
||||
model_outputs = get_model_outputs(test_model, 'get_bboxes', model_inputs)
|
||||
model_outputs = get_model_outputs(vfnet_head, 'get_bboxes', model_inputs)
|
||||
|
||||
img_metas[0]['img_shape'] = torch.Tensor([s, s])
|
||||
wrapped_model = WrapModel(
|
||||
test_model, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {
|
||||
'cls_scores': cls_score,
|
||||
'bbox_preds': bboxes,
|
||||
'bbox_preds_refine': bbox_preds_refine
|
||||
}
|
||||
vfnet_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
|
||||
rewrite_inputs = {'cls_scores': cls_score, 'bbox_preds': bboxes}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
|
|
Loading…
Reference in New Issue