mirror of https://github.com/open-mmlab/mmocr.git
fix sartrn onnxruntime test (#679)
* fix satrn test * disable aug test for deployment testpull/691/head
parent
60dfb2a85b
commit
abbae7bfd1
|
@ -43,7 +43,8 @@ class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector):
|
|||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
cfg.model.pop('type')
|
||||
if 'type' in cfg.model:
|
||||
cfg.model.pop('type')
|
||||
SingleStageTextDetector.__init__(self, **(cfg.model))
|
||||
TextDetectorMixin.__init__(self, show_score)
|
||||
import onnxruntime as ort
|
||||
|
@ -117,7 +118,8 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
|
|||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
cfg.model.pop('type')
|
||||
if 'type' in cfg.model:
|
||||
cfg.model.pop('type')
|
||||
EncodeDecodeRecognizer.__init__(self, **(cfg.model))
|
||||
import onnxruntime as ort
|
||||
# get the custom op path
|
||||
|
@ -154,7 +156,16 @@ class ONNXRuntimeRecognizer(EncodeDecodeRecognizer):
|
|||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
if isinstance(imgs, list):
|
||||
for idx, each_img in enumerate(imgs):
|
||||
if each_img.dim() == 3:
|
||||
imgs[idx] = each_img.unsqueeze(0)
|
||||
imgs = imgs[0] # avoid aug_test
|
||||
img_metas = img_metas[0]
|
||||
else:
|
||||
if len(img_metas) == 1 and isinstance(img_metas[0], list):
|
||||
img_metas = img_metas[0]
|
||||
return self.simple_test(imgs, img_metas=img_metas)
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
@ -197,7 +208,8 @@ class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector):
|
|||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
cfg.model.pop('type')
|
||||
if 'type' in cfg.model:
|
||||
cfg.model.pop('type')
|
||||
SingleStageTextDetector.__init__(self, **(cfg.model))
|
||||
TextDetectorMixin.__init__(self, show_score)
|
||||
from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
|
||||
|
@ -252,7 +264,8 @@ class TensorRTRecognizer(EncodeDecodeRecognizer):
|
|||
cfg: Any,
|
||||
device_id: int,
|
||||
show_score: bool = False):
|
||||
cfg.model.pop('type')
|
||||
if 'type' in cfg.model:
|
||||
cfg.model.pop('type')
|
||||
EncodeDecodeRecognizer.__init__(self, **(cfg.model))
|
||||
from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin
|
||||
try:
|
||||
|
@ -271,7 +284,16 @@ class TensorRTRecognizer(EncodeDecodeRecognizer):
|
|||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def aug_test(self, imgs, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
if isinstance(imgs, list):
|
||||
for idx, each_img in enumerate(imgs):
|
||||
if each_img.dim() == 3:
|
||||
imgs[idx] = each_img.unsqueeze(0)
|
||||
imgs = imgs[0] # avoid aug_test
|
||||
img_metas = img_metas[0]
|
||||
else:
|
||||
if len(img_metas) == 1 and isinstance(img_metas[0], list):
|
||||
img_metas = img_metas[0]
|
||||
return self.simple_test(imgs, img_metas=img_metas)
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
|
|
@ -6,6 +6,7 @@ from mmcv.parallel import MMDataParallel
|
|||
from mmcv.runner import get_dist_info
|
||||
from mmdet.apis import single_gpu_test
|
||||
|
||||
from mmocr.apis.inference import disable_text_recog_aug_test
|
||||
from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
|
||||
TensorRTDetector, TensorRTRecognizer)
|
||||
from mmocr.datasets import build_dataloader, build_dataset
|
||||
|
@ -63,6 +64,7 @@ def main():
|
|||
|
||||
# build the dataloader
|
||||
samples_per_gpu = 1
|
||||
cfg = disable_text_recog_aug_test(cfg)
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
|
|
Loading…
Reference in New Issue