fix sartrn onnxruntime test (#679)

* fix satrn test

* disable aug test for deployment test
pull/691/head
AllentDan 2021-12-21 13:52:11 +08:00 committed by GitHub
parent 60dfb2a85b
commit abbae7bfd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 6 deletions

View File

@ -43,7 +43,8 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
cfg: Any,
device_id: int,
show_score: bool = False):
cfg.model.pop('type')
if 'type' in cfg.model:
cfg.model.pop('type')
SingleStageTextDetector.__init__(self, **(cfg.model))
TextDetectorMixin.__init__(self, show_score)
import onnxruntime as ort
@ -117,7 +118,8 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
cfg: Any,
device_id: int,
show_score: bool = False):
cfg.model.pop('type')
if 'type' in cfg.model:
cfg.model.pop('type')
EncodeDecodeRecognizer.__init__(self, **(cfg.model))
import onnxruntime as ort
# get the custom op path
@ -154,7 +156,16 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
raise NotImplementedError('This method is not implemented.')
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
if isinstance(imgs, list):
for idx, each_img in enumerate(imgs):
if each_img.dim() == 3:
imgs[idx] = each_img.unsqueeze(0)
imgs = imgs[0] # avoid aug_test
img_metas = img_metas[0]
else:
if len(img_metas) == 1 and isinstance(img_metas[0], list):
img_metas = img_metas[0]
return self.simple_test(imgs, img_metas=img_metas)
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
@ -197,7 +208,8 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
cfg: Any,
device_id: int,
show_score: bool = False):
cfg.model.pop('type')
if 'type' in cfg.model:
cfg.model.pop('type')
SingleStageTextDetector.__init__(self, **(cfg.model))
TextDetectorMixin.__init__(self, show_score)
from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
@ -252,7 +264,8 @@ class TensorRTRecognizer(EncodeDecodeRecognizer):
cfg: Any,
device_id: int,
show_score: bool = False):
cfg.model.pop('type')
if 'type' in cfg.model:
cfg.model.pop('type')
EncodeDecodeRecognizer.__init__(self, **(cfg.model))
from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
try:
@ -271,7 +284,16 @@ class TensorRTRecognizer(EncodeDecodeRecognizer):
raise NotImplementedError('This method is not implemented.')
def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')
if isinstance(imgs, list):
for idx, each_img in enumerate(imgs):
if each_img.dim() == 3:
imgs[idx] = each_img.unsqueeze(0)
imgs = imgs[0] # avoid aug_test
img_metas = img_metas[0]
else:
if len(img_metas) == 1 and isinstance(img_metas[0], list):
img_metas = img_metas[0]
return self.simple_test(imgs, img_metas=img_metas)
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')

View File

@ -6,6 +6,7 @@ from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
from mmdet.apis import single_gpu_test
from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets import build_dataloader, build_dataset
@ -63,6 +64,7 @@ def main():
# build the dataloader
samples_per_gpu = 1
cfg = disable_text_recog_aug_test(cfg)
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,