From abbae7bfd1aae4109970ad82ef9b087cd6c11000 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Tue, 21 Dec 2021 13:52:11 +0800 Subject: [PATCH] fix sartrn onnxruntime test (#679) * fix satrn test * disable aug test for deployment test --- mmocr/core/deployment/deploy_utils.py | 34 ++++++++++++++++++++++----- tools/deployment/deploy_test.py | 2 ++ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/mmocr/core/deployment/deploy_utils.py b/mmocr/core/deployment/deploy_utils.py index 27a1c772..a82a324c 100644 --- a/mmocr/core/deployment/deploy_utils.py +++ b/mmocr/core/deployment/deploy_utils.py @@ -43,7 +43,8 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector): cfg: Any, device_id: int, show_score: bool = False): - cfg.model.pop('type') + if 'type' in cfg.model: + cfg.model.pop('type') SingleStageTextDetector.__init__(self, **(cfg.model)) TextDetectorMixin.__init__(self, show_score) import onnxruntime as ort @@ -117,7 +118,8 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer): cfg: Any, device_id: int, show_score: bool = False): - cfg.model.pop('type') + if 'type' in cfg.model: + cfg.model.pop('type') EncodeDecodeRecognizer.__init__(self, **(cfg.model)) import onnxruntime as ort # get the custom op path @@ -154,7 +156,16 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer): raise NotImplementedError('This method is not implemented.') def aug_test(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + if isinstance(imgs, list): + for idx, each_img in enumerate(imgs): + if each_img.dim() == 3: + imgs[idx] = each_img.unsqueeze(0) + imgs = imgs[0] # avoid aug_test + img_metas = img_metas[0] + else: + if len(img_metas) == 1 and isinstance(img_metas[0], list): + img_metas = img_metas[0] + return self.simple_test(imgs, img_metas=img_metas) def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') @@ -197,7 +208,8 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector): cfg: Any, device_id: int, show_score: bool = False): - cfg.model.pop('type') + if 'type' in cfg.model: + cfg.model.pop('type') SingleStageTextDetector.__init__(self, **(cfg.model)) TextDetectorMixin.__init__(self, show_score) from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin @@ -252,7 +264,8 @@ class TensorRTRecognizer(EncodeDecodeRecognizer): cfg: Any, device_id: int, show_score: bool = False): - cfg.model.pop('type') + if 'type' in cfg.model: + cfg.model.pop('type') EncodeDecodeRecognizer.__init__(self, **(cfg.model)) from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin try: @@ -271,7 +284,16 @@ class TensorRTRecognizer(EncodeDecodeRecognizer): raise NotImplementedError('This method is not implemented.') def aug_test(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') + if isinstance(imgs, list): + for idx, each_img in enumerate(imgs): + if each_img.dim() == 3: + imgs[idx] = each_img.unsqueeze(0) + imgs = imgs[0] # avoid aug_test + img_metas = img_metas[0] + else: + if len(img_metas) == 1 and isinstance(img_metas[0], list): + img_metas = img_metas[0] + return self.simple_test(imgs, img_metas=img_metas) def extract_feat(self, imgs): raise NotImplementedError('This method is not implemented.') diff --git a/tools/deployment/deploy_test.py b/tools/deployment/deploy_test.py index 33580663..7934087b 100644 --- a/tools/deployment/deploy_test.py +++ b/tools/deployment/deploy_test.py @@ -6,6 +6,7 @@ from mmcv.parallel import MMDataParallel from mmcv.runner import get_dist_info from mmdet.apis import single_gpu_test +from mmocr.apis.inference import disable_text_recog_aug_test from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer, TensorRTDetector, TensorRTRecognizer) from mmocr.datasets import build_dataloader, build_dataset @@ -63,6 +64,7 @@ def main(): # build the dataloader samples_per_gpu = 1 + cfg = disable_text_recog_aug_test(cfg) dataset = build_dataset(cfg.data.test) data_loader = build_dataloader( dataset,