[Fix] Fix mmdeploy rewriter ()

* fix mmdeploy rewriter

* fix unit test
pull/368/head
q.yao 2022-12-14 09:59:32 +08:00 committed by GitHub
parent ac34b80e38
commit d640e7b310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 5 deletions
mmyolo/deploy/models/dense_heads

View File

@ -48,8 +48,7 @@ def yolov5_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmyolo.models.dense_heads.yolov5_head.'
'YOLOv5Head.predict_by_feat')
def yolov5_head__predict_by_feat(ctx,
self,
def yolov5_head__predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]] = None,
@ -85,6 +84,7 @@ def yolov5_head__predict_by_feat(ctx,
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""
ctx = FUNCTION_REWRITER.get_context()
detector_type = type(self)
deploy_cfg = ctx.cfg
use_efficientnms = deploy_cfg.get('use_efficientnms', False)
@ -163,7 +163,7 @@ def yolov5_head__predict_by_feat(ctx,
func_name='mmyolo.models.dense_heads.yolov5_head.'
'YOLOv5Head.predict',
backend='rknn')
def yolov5_head__predict__rknn(ctx, self, x: Tuple[Tensor], *args,
def yolov5_head__predict__rknn(self, x: Tuple[Tensor], *args,
**kwargs) -> Tuple[Tensor, Tensor, Tensor]:
"""Perform forward propagation of the detection head and predict detection
results on the features of the upstream network.
@ -181,8 +181,7 @@ def yolov5_head__predict__rknn(ctx, self, x: Tuple[Tensor], *args,
'YOLOv5HeadModule.forward',
backend='rknn')
def yolov5_head_module__forward__rknn(
ctx, self, x: Tensor, *args,
**kwargs) -> Tuple[Tensor, Tensor, Tensor]:
self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
"""Forward feature of a single scale level."""
out = []
for i, feat in enumerate(x):