[Enhancement] Support mmocr v0.4+ (#115)

* support mmocr v0.4+

* 0.4.0 -> 0.4.1
This commit is contained in:
AllentDan 2022-02-07 13:47:38 +08:00 committed by GitHub
parent 230596bad9
commit 51fa2ff566
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 4 deletions

View File

@ -98,7 +98,7 @@ class End2EndModel(BaseBackendModel):
def show_result(self, def show_result(self,
img: np.ndarray, img: np.ndarray,
result: list, result: list,
win_name: str, win_name: str = '',
show: bool = True, show: bool = True,
out_file: str = None): out_file: str = None):
"""Show predictions of classification. """Show predictions of classification.

View File

@ -4,6 +4,7 @@ from typing import Optional, Union
import mmcv import mmcv
import torch import torch
from mmcv.utils import Registry from mmcv.utils import Registry
from packaging import version
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
@ -137,6 +138,11 @@ class MMOCR(MMCodebase):
Returns: Returns:
list: The prediction results. list: The prediction results.
""" """
import mmocr
# fixed the bug when using `--show-dir` after mocr v0.4.1
if version.parse(mmocr.__version__) < version.parse('0.4.1'):
from mmdet.apis import single_gpu_test from mmdet.apis import single_gpu_test
else:
from mmocr.apis import single_gpu_test
outputs = single_gpu_test(model, data_loader, show, out_dir, **kwargs) outputs = single_gpu_test(model, data_loader, show, out_dir, **kwargs)
return outputs return outputs

View File

@ -118,7 +118,7 @@ class End2EndModel(BaseBackendModel):
def show_result(self, def show_result(self,
img: np.ndarray, img: np.ndarray,
result: dict, result: dict,
win_name: str, win_name: str = '',
show: bool = True, show: bool = True,
score_thr: float = 0.3, score_thr: float = 0.3,
out_file: str = None): out_file: str = None):

View File

@ -125,7 +125,7 @@ class End2EndModel(BaseBackendModel):
def show_result(self, def show_result(self,
img: np.ndarray, img: np.ndarray,
result: list, result: list,
win_name: str, win_name: str = '',
show: bool = True, show: bool = True,
score_thr: float = 0.3, score_thr: float = 0.3,
out_file: str = None): out_file: str = None):