mirror of https://github.com/open-mmlab/mmocr.git
[Refactor] ocr.py (#1344)
* [Feature] Add BaseInferencer
* [Feature] Add Det&Rec Inferencer
* [Feature] Add KIEInferencer
* [Feature] Add MMOCRInferencer
* [Refactor] update ocr.py
* update links
* update two links
* remove ocr.py
* move ocr.py and add loadfromndarray

Co-authored-by: xinyu <wangxinyu2017@gmail.com>
Co-authored-by: liukuikun <liukuikun@sensetime.com>

pull/1362/head
parent dbb346afed
commit db6ce0d95e
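For context, a minimal usage sketch of the inferencer API this commit introduces. The checkpoint filenames and the demo image path are placeholders, not files shipped with the repo; the config paths are taken from the `model_dict` added in `ocr.py` below:

```python
from mmocr.apis.inferencers import MMOCRInferencer
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=True)

# Hypothetical local checkpoints -- substitute your own paths.
ocr = MMOCRInferencer(
    det_config='configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py',
    det_ckpt='dbnet_r18.pth',
    rec_config='configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py',
    rec_ckpt='crnn.pth')
results = ocr('demo/demo_text_ocr.jpg', print_result=True)
```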
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inferencers import *  # NOQA

@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .kie_inferencer import KIEInferencer
from .mmocr_inferencer import MMOCRInferencer
from .textdet_inferencer import TextDetInferencer
from .textrec_inferencer import TextRecInferencer

__all__ = [
    'TextDetInferencer', 'TextRecInferencer', 'KIEInferencer',
    'MMOCRInferencer'
]

@@ -0,0 +1,195 @@
# Copyright (c) OpenMMLab. All rights reserved.
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData

from mmocr.registry import MODELS, VISUALIZERS
from mmocr.utils import ConfigType

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict]]


class BaseInferencer:
    """Base inferencer.

    Args:
        model (str or ConfigType): Model config or the path to it.
        ckpt (str, optional): Path to the checkpoint.
        device (str, optional): Device to run inference. If None, the best
            available device will be used automatically.
        show (bool): Whether to display the image in a popup window.
            Defaults to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        draw_pred (bool): Whether to draw predicted bounding boxes.
            Defaults to True.
        pred_score_thr (float): Minimum score of bboxes to draw.
            Defaults to 0.3.
        img_out_dir (str): Output directory of images. Defaults to ''.
        pred_out_file (str): File to save the inference results. If left
            empty, no file will be saved.
        print_result (bool): Whether to print the result.
            Defaults to False.
    """

    func_kwargs = dict(preprocess=[], forward=[], visualize=[], postprocess=[])
    func_order = dict(preprocess=0, forward=1, visualize=2, postprocess=3)

    def __init__(self,
                 config: Union[ConfigType, str],
                 ckpt: Optional[str],
                 device: Optional[str] = None,
                 **kwargs) -> None:
        # Load config to cfg
        if isinstance(config, str):
            cfg = Config.fromfile(config)
        elif isinstance(config, Config):
            cfg = config
        else:
            raise TypeError('config must be a filename or a Config object, '
                            f'but got {type(config)}')
        if cfg.model.get('pretrained'):
            cfg.model.pretrained = None

        if device is None:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
        self._init_model(cfg, ckpt, device)
        self._init_pipeline(cfg)
        self._init_visualizer(cfg)
        self.base_params = self._dispatch_kwargs(**kwargs)

    def _init_model(self, cfg: Union[ConfigType, str], ckpt: Optional[str],
                    device: str) -> None:
        """Initialize the model with the given config and checkpoint on the
        specific device."""
        model = MODELS.build(cfg.model)
        if ckpt is not None:
            ckpt = load_checkpoint(model, ckpt, map_location='cpu')
        model.cfg = cfg.model
        model.to(device)
        model.eval()
        self.model = model

    def _init_pipeline(self, cfg: ConfigType) -> None:
        """Initialize the test pipeline."""
        raise NotImplementedError

    def _init_visualizer(self, cfg: ConfigType) -> None:
        """Initialize visualizers."""
        # TODO: We don't export images via backends since the interface
        # of the visualizer will have to be refactored.
        self.visualizer = None
        if 'visualizer' in cfg:
            ts = str(datetime.timestamp(datetime.now()))
            cfg.visualizer['name'] = f'inferencer{ts}'
            self.visualizer = VISUALIZERS.build(cfg.visualizer)

    def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
        """Dispatch kwargs to preprocess(), forward(), visualize() and
        postprocess() according to the actual demands."""
        results = [{}, {}, {}, {}]
        dispatched_kwargs = set()

        # Dispatch kwargs according to self.func_kwargs
        for func_name, func_kwargs in self.func_kwargs.items():
            for func_kwarg in func_kwargs:
                if func_kwarg in kwargs:
                    dispatched_kwargs.add(func_kwarg)
                    results[self.func_order[func_name]][func_kwarg] = kwargs[
                        func_kwarg]

        # Find if there are any kwargs that were not dispatched
        for kwarg in kwargs:
            if kwarg not in dispatched_kwargs:
                raise ValueError(f'Unknown kwarg: {kwarg}')

        return results

    def __call__(self, inputs: InputsType,
                 **kwargs) -> Union[Dict, List[Dict]]:
        """Call the inferencer.

        Args:
            inputs: Inputs for the inferencer.
            kwargs: Keyword arguments for the inferencer.
        """

        params = self._dispatch_kwargs(**kwargs)
        preprocess_kwargs = self.base_params[0].copy()
        preprocess_kwargs.update(params[0])
        forward_kwargs = self.base_params[1].copy()
        forward_kwargs.update(params[1])
        visualize_kwargs = self.base_params[2].copy()
        visualize_kwargs.update(params[2])
        postprocess_kwargs = self.base_params[3].copy()
        postprocess_kwargs.update(params[3])

        data = self.preprocess(inputs, **preprocess_kwargs)
        preds = self.forward(data, **forward_kwargs)
        imgs = self.visualize(inputs, preds, **visualize_kwargs)
        results = self.postprocess(preds, imgs, **postprocess_kwargs)
        return results

    def preprocess(self, inputs: InputsType) -> List[Dict]:
        """Process the inputs into a model-feedable format."""
        raise NotImplementedError

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model."""
        with torch.no_grad():
            return self.model.test_step(inputs)

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  img_out_dir: str = '') -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            img_out_dir (str): Output directory of images. Defaults to ''.
        """
        raise NotImplementedError

    def postprocess(
        self,
        preds: PredType,
        imgs: Optional[List[np.ndarray]] = None,
    ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
        """Postprocess predictions.

        Args:
            preds (List[Dict]): Predictions of the model.
            imgs (Optional[np.ndarray]): Visualized predictions.
            is_batch (bool): Whether the inputs are in a batch.
                Defaults to False.
            print_result (bool): Whether to print the result.
                Defaults to False.
            pred_out_file (str): Output file name to store predictions
                without images. Supported file formats are "json", "yaml/yml"
                and "pickle/pkl". Defaults to ''.

        Returns:
            TODO
        """
        raise NotImplementedError

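To make the `func_kwargs` dispatch contract concrete, here is a hedged sketch of a hypothetical subclass (the class name and its trivial pipeline are illustrative only; instantiating it still requires a real config and checkpoint):

```python
class ToyInferencer(BaseInferencer):
    # Declare which call-time kwargs each stage accepts; anything not
    # listed here is rejected by _dispatch_kwargs with a ValueError.
    func_kwargs = dict(
        preprocess=[], forward=[], visualize=[],
        postprocess=['print_result'])

    def _init_pipeline(self, cfg):
        # Identity pipeline, purely for illustration.
        self.pipeline = lambda x: x

    def preprocess(self, inputs):
        return [self.pipeline(x) for x in inputs]

    def visualize(self, inputs, preds, **kwargs):
        return None

    def postprocess(self, preds, imgs=None, print_result=False):
        # 'print_result' arrives here because it is declared above.
        if print_result:
            print(preds)
        return preds
```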
@@ -0,0 +1,267 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
from mmengine.dataset import Compose
from mmengine.structures import InstanceData

from mmocr.utils import ConfigType
from .base_inferencer import BaseInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


class BaseMMOCRInferencer(BaseInferencer):
    """Base MMOCR inferencer.

    Args:
        model (str or ConfigType): Model config or the path to it.
        ckpt (str, optional): Path to the checkpoint.
        device (str, optional): Device to run inference. If None, the best
            available device will be used automatically.
        show (bool): Whether to display the image in a popup window.
            Defaults to False.
        wait_time (float): The interval of show (s). Defaults to 0.
        draw_pred (bool): Whether to draw predicted bounding boxes.
            Defaults to True.
        pred_score_thr (float): Minimum score of bboxes to draw.
            Defaults to 0.3.
        img_out_dir (str): Output directory of images. Defaults to ''.
        pred_out_file (str): File to save the inference results. If left
            empty, no file will be saved.
        print_result (bool): Whether to print the result.
            Defaults to False.
    """

    func_kwargs = dict(
        preprocess=[],
        forward=[],
        visualize=[
            'show', 'wait_time', 'draw_pred', 'pred_score_thr', 'img_out_dir'
        ],
        postprocess=['print_result', 'pred_out_file', 'get_datasample'])

    def __init__(self,
                 config: Union[ConfigType, str],
                 ckpt: Optional[str],
                 device: Optional[str] = None,
                 **kwargs) -> None:
        # A global counter tracking the number of images processed, for
        # naming of the output images
        self.num_visualized_imgs = 0
        super().__init__(config=config, ckpt=ckpt, device=device, **kwargs)

    def _init_pipeline(self, cfg: ConfigType) -> None:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline

        # For inference, the key of ``instances`` is not used.
        if 'meta_keys' in pipeline_cfg[-1]:
            pipeline_cfg[-1]['meta_keys'] = tuple(
                meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
                if meta_key != 'instances')

        # Loading annotations is also not applicable
        idx = self._get_transform_idx(pipeline_cfg, 'LoadOCRAnnotations')
        if idx != -1:
            del pipeline_cfg[idx]

        self.file_pipeline = Compose(pipeline_cfg)

        load_img_idx = self._get_transform_idx(pipeline_cfg,
                                               'LoadImageFromFile')
        if load_img_idx == -1:
            raise ValueError(
                'LoadImageFromFile is not found in the test pipeline')
        pipeline_cfg[load_img_idx]['type'] = 'LoadImageFromNDArray'
        self.ndarray_pipeline = Compose(pipeline_cfg)

    def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
        """Returns the index of the transform in a pipeline.

        If the transform is not found, returns -1.
        """
        for i, transform in enumerate(pipeline_cfg):
            if transform['type'] == name:
                return i
        return -1

    def preprocess(self, inputs: InputsType) -> Dict:
        """Process the inputs into a model-feedable format."""
        results = []
        for single_input in inputs:
            if isinstance(single_input, str):
                if osp.isdir(single_input):
                    raise ValueError('Feeding a directory is not supported')
                    # for img_path in os.listdir(single_input):
                    #     data_ = dict(img_path=osp.join(single_input,
                    #                                    img_path))
                    #     results.append(self.file_pipeline(data_))
                else:
                    data_ = dict(img_path=single_input)
                    results.append(self.file_pipeline(data_))
            elif isinstance(single_input, np.ndarray):
                data_ = dict(img=single_input)
                results.append(self.ndarray_pipeline(data_))
            else:
                raise ValueError(
                    f'Unsupported input type: {type(single_input)}')

        return self._collate(results)

    def _collate(self, results: List[Dict]) -> Dict:
        """Collate the results from different images."""
        results = {key: [d[key] for d in results] for key in results[0]}
        return results

    def __call__(self, user_inputs: InputsType,
                 **kwargs) -> Union[Dict, List[Dict]]:
        """Call the inferencer.

        Args:
            user_inputs: Inputs for the inferencer.
            kwargs: Keyword arguments for the inferencer.
        """

        # Detect if user_inputs are in a batch
        is_batch = isinstance(user_inputs, (list, tuple))
        inputs = user_inputs if is_batch else [user_inputs]

        params = self._dispatch_kwargs(**kwargs)
        preprocess_kwargs = self.base_params[0].copy()
        preprocess_kwargs.update(params[0])
        forward_kwargs = self.base_params[1].copy()
        forward_kwargs.update(params[1])
        visualize_kwargs = self.base_params[2].copy()
        visualize_kwargs.update(params[2])
        postprocess_kwargs = self.base_params[3].copy()
        postprocess_kwargs.update(params[3])

        data = self.preprocess(inputs, **preprocess_kwargs)
        preds = self.forward(data, **forward_kwargs)
        imgs = self.visualize(inputs, preds, **visualize_kwargs)
        results = self.postprocess(
            preds, imgs, is_batch=is_batch, **postprocess_kwargs)
        return results

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  img_out_dir: str = '') -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            img_out_dir (str): Output directory of images. Defaults to ''.
        """
        if self.visualizer is None or (not show and img_out_dir == ''):
            return None

        if getattr(self, 'visualizer') is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

        for single_input, pred in zip(inputs, preds):
            if isinstance(single_input, str):
                img = mmcv.imread(single_input)
                img = img[:, :, ::-1]
                img_name = osp.basename(single_input)
            elif isinstance(single_input, np.ndarray):
                img = single_input.copy()
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Unsupported input type: '
                                 f'{type(single_input)}')

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(img)
            self.num_visualized_imgs += 1

        return results

    def postprocess(
        self,
        preds: PredType,
        imgs: Optional[List[np.ndarray]] = None,
        is_batch: bool = False,
        print_result: bool = False,
        pred_out_file: str = '',
        get_datasample: bool = False,
    ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
        """Postprocess predictions.

        Args:
            preds (List[Dict]): Predictions of the model.
            imgs (Optional[np.ndarray]): Visualized predictions.
            is_batch (bool): Whether the inputs are in a batch.
                Defaults to False.
            print_result (bool): Whether to print the result.
                Defaults to False.
            pred_out_file (str): Output file name to store predictions
                without images. Supported file formats are "json", "yaml/yml"
                and "pickle/pkl". Defaults to ''.
            get_datasample (bool): Whether to use Datasample to store
                inference results. If False, dict will be used.

        Returns:
            TODO
        """
        results = preds
        if not get_datasample:
            results = []
            for pred in preds:
                result = self._pred2dict(pred)
                results.append(result)
        if not is_batch:
            results = results[0]
        if print_result:
            print(results)
        # Add img to the results after printing
        if pred_out_file != '':
            mmcv.dump(results, pred_out_file)
        if imgs is None:
            return results
        return results, imgs

    def _pred2dict(self, data_sample: InstanceData) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary.

        It's better to contain only basic data elements such as strings and
        numbers in order to guarantee it's json-serializable.
        """
        raise NotImplementedError

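A hedged sketch of the batching behavior implemented in `__call__` above (checkpoint filename and image path are placeholders):

```python
import numpy as np
from mmocr.apis.inferencers import TextDetInferencer

# Any text detection config from this repo would do here.
inferencer = TextDetInferencer(
    'configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py',
    'dbnet_r18.pth')

# A single path or array is wrapped internally and yields one result dict,
result = inferencer('demo.jpg')

# while a list/tuple is treated as a batch and yields a list of dicts.
results = inferencer(['demo.jpg', np.zeros((224, 224, 3), dtype=np.uint8)])
```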
@@ -0,0 +1,194 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Sequence

import mmcv
import numpy as np
from mmengine.dataset import Compose

from mmocr.registry import DATASETS, VISUALIZERS
from mmocr.structures import KIEDataSample
from mmocr.utils import ConfigType
from .base_mmocr_inferencer import BaseMMOCRInferencer, PredType

InputType = Dict
InputsType = Sequence[Dict]


class KIEInferencer(BaseMMOCRInferencer):
    """Key information extraction inferencer.

    Inputs:
        dict or list[dict]: A dictionary containing the following keys:

        - img (str or ndarray): Path to the image or the image itself.
        - img_shape (tuple(int, int)): Image shape in (H, W). Required when
          ``img`` is not provided.
        - instances (list[dict]): A list of instances, where each instance
          contains:

          - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box in
            (x1, y1, x2, y2) order.
          - text (str): Annotation text.

        .. code-block:: python

            {
                'img': 'path/to/image.jpg',
                'instances': [
                    dict(
                        bbox=np.array([x1, y1, x2, y2], dtype=np.float32),
                        text='hello'),
                    ...
                ]
            }
    """

    def _init_pipeline(self, cfg: ConfigType) -> None:
        """Initialize the test pipeline."""
        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        idx = self._get_transform_idx(pipeline_cfg, 'LoadKIEAnnotations')
        if idx == -1:
            raise ValueError(
                'LoadKIEAnnotations is not found in the test pipeline')
        pipeline_cfg[idx]['with_label'] = False
        self.novisual = self._get_transform_idx(pipeline_cfg,
                                                'LoadImageFromFile') == -1
        # If it's in non-visual mode, self.pipeline will be specified.
        # Otherwise, file_pipeline and ndarray_pipeline will be specified.
        if self.novisual:
            self.pipeline = Compose(pipeline_cfg)
        else:
            return super()._init_pipeline(cfg)

    def _init_visualizer(self, cfg: ConfigType) -> None:
        """Initialize visualizers."""
        # TODO: We don't export images via backends since the interface
        # of the visualizer will have to be refactored.
        self.visualizer = None
        if 'visualizer' in cfg:
            self.visualizer = VISUALIZERS.build(cfg.visualizer)
            dataset = DATASETS.build(cfg.test_dataloader.dataset)
            self.visualizer.dataset_meta = dataset.metainfo

    def preprocess(self, inputs: InputsType) -> List[Dict]:
        results = []
        for single_input in inputs:
            if self.novisual:
                if 'img' not in single_input and \
                        'img_shape' not in single_input:
                    raise ValueError(
                        'KIEInferencer in no-visual mode '
                        'requires input to have "img" or "img_shape", but '
                        'neither was found.')
                if 'img' in single_input:
                    new_input = {
                        k: v
                        for k, v in single_input.items() if k != 'img'
                    }
                    img = single_input['img']
                    if isinstance(img, str):
                        img = mmcv.imread(img)
                    # Derive (H, W) from the image when only ``img`` is given
                    new_input['img_shape'] = img.shape[:2]
                else:
                    new_input = dict(single_input)
                results.append(self.pipeline(new_input))
            else:
                if 'img' not in single_input:
                    raise ValueError(
                        'This inferencer is constructed to '
                        'accept image inputs, but the input does not contain '
                        '"img" key.')
                if isinstance(single_input['img'], str):
                    data_ = {
                        k: v
                        for k, v in single_input.items() if k != 'img'
                    }
                    data_['img_path'] = single_input['img']
                    results.append(self.file_pipeline(data_))
                elif isinstance(single_input['img'], np.ndarray):
                    results.append(self.ndarray_pipeline(single_input))
                else:
                    atype = type(single_input['img'])
                    raise ValueError(f'Unsupported input type: {atype}')
        return self._collate(results)

    def visualize(self,
                  inputs: InputsType,
                  preds: PredType,
                  show: bool = False,
                  wait_time: int = 0,
                  draw_pred: bool = True,
                  pred_score_thr: float = 0.3,
                  img_out_dir: str = '') -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            img_out_dir (str): Output directory of images. Defaults to ''.
        """
        if self.visualizer is None or (not show and img_out_dir == ''):
            return None

        if getattr(self, 'visualizer') is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

        results = []

        for single_input, pred in zip(inputs, preds):
            assert 'img' in single_input or 'img_shape' in single_input
            if 'img' in single_input:
                if isinstance(single_input['img'], str):
                    img = mmcv.imread(single_input['img'])
                    img_name = osp.basename(single_input['img'])
                elif isinstance(single_input['img'], np.ndarray):
                    img = single_input['img'].copy()
                    img_num = str(self.num_visualized_imgs).zfill(8)
                    img_name = f'{img_num}.jpg'
            elif 'img_shape' in single_input:
                img = np.zeros(single_input['img_shape'], dtype=np.uint8)
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
            else:
                raise ValueError('Input does not contain either "img" or '
                                 '"img_shape"')

            out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
                else None

            self.visualizer.add_datasample(
                img_name,
                img,
                pred,
                show=show,
                wait_time=wait_time,
                draw_gt=False,
                draw_pred=draw_pred,
                pred_score_thr=pred_score_thr,
                out_file=out_file,
            )
            results.append(img)
            self.num_visualized_imgs += 1

        return results

    def _pred2dict(self, data_sample: KIEDataSample) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary. It's better to contain only basic data elements such as
        strings and numbers in order to guarantee it's json-serializable.

        Args:
            data_sample (KIEDataSample): The data sample to be converted.

        Returns:
            dict: The output dictionary.
        """
        result = {}
        pred = data_sample.pred_instances
        result['scores'] = pred.scores.cpu().numpy().tolist()
        result['edge_scores'] = pred.edge_scores.cpu().numpy().tolist()
        result['edge_labels'] = pred.edge_labels.cpu().numpy().tolist()
        result['labels'] = pred.labels.cpu().numpy().tolist()
        return result

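A hedged sketch of a valid input for the no-visual path described in the docstring above (all field values are placeholders):

```python
import numpy as np

# One input dict per image. ``img_shape`` may replace ``img`` when the
# inferencer was built from a config without LoadImageFromFile.
kie_input = dict(
    img_shape=(1080, 1920),
    instances=[
        dict(bbox=np.array([66, 45, 300, 80], dtype=np.float32),
             text='Receipt'),
        dict(bbox=np.array([50, 100, 280, 140], dtype=np.float32),
             text='Total: $10.00'),
    ])

# results = kie_inferencer([kie_input])
```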
@@ -0,0 +1,233 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import mmcv
import numpy as np

from mmocr.registry import VISUALIZERS
from mmocr.structures.textdet_data_sample import TextDetDataSample
from mmocr.utils import ConfigType, bbox2poly, crop_img, poly2bbox
from .base_mmocr_inferencer import (BaseMMOCRInferencer, InputsType, PredType,
                                    ResType)
from .kie_inferencer import KIEInferencer
from .textdet_inferencer import TextDetInferencer
from .textrec_inferencer import TextRecInferencer


class MMOCRInferencer(BaseMMOCRInferencer):

    def __init__(self,
                 det_config: Optional[Union[ConfigType, str]] = None,
                 det_ckpt: Optional[str] = None,
                 rec_config: Optional[Union[ConfigType, str]] = None,
                 rec_ckpt: Optional[str] = None,
                 kie_config: Optional[Union[ConfigType, str]] = None,
                 kie_ckpt: Optional[str] = None,
                 device: Optional[str] = None,
                 **kwargs) -> None:

        self.visualizer = None
        self.base_params = self._dispatch_kwargs(**kwargs)
        self.num_visualized_imgs = 0

        if det_config is not None:
            self.textdet_inferencer = TextDetInferencer(
                det_config, det_ckpt, device)
            self.mode = 'det'
        if rec_config is not None:
            self.textrec_inferencer = TextRecInferencer(
                rec_config, rec_ckpt, device)
            if getattr(self, 'mode', None) == 'det':
                self.mode = 'det_rec'
                ts = str(datetime.timestamp(datetime.now()))
                self.visualizer = VISUALIZERS.build(
                    dict(
                        type='TextSpottingLocalVisualizer',
                        name=f'inferencer{ts}'))
            else:
                self.mode = 'rec'
        if kie_config is not None:
            if det_config is None or rec_config is None:
                raise ValueError(
                    'kie_config is only applicable when det_config and '
                    'rec_config are both provided')
            self.kie_inferencer = KIEInferencer(kie_config, kie_ckpt, device)
            self.mode = 'det_rec_kie'

    def preprocess(self, inputs: InputsType):
        new_inputs = []
        for single_input in inputs:
            if isinstance(single_input, str):
                if osp.isdir(single_input):
                    raise ValueError('Feeding a directory is not supported')
                    # for img_path in os.listdir(single_input):
                    #     new_inputs.append(
                    #         mmcv.imread(osp.join(single_input, img_path)))
                else:
                    single_input = mmcv.imread(single_input)
                    new_inputs.append(single_input)
            else:
                new_inputs.append(single_input)
        return new_inputs

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model.

        Args:
            inputs (InputsType): The inputs to be forwarded.

        Returns:
            Dict: The prediction results. Possibly with keys "det", "rec",
            and "kie".
        """
        result = {}
        if self.mode == 'rec':
            # The extra list wrapper here is for the ease of postprocessing
            self.rec_inputs = inputs
            result['rec'] = [
                self.textrec_inferencer(self.rec_inputs, get_datasample=True)
            ]
        elif self.mode.startswith('det'):
            result['det'] = self.textdet_inferencer(
                inputs, get_datasample=True)
            if self.mode.startswith('det_rec'):
                result['rec'] = []
                for img, det_data_sample in zip(inputs, result['det']):
                    det_pred = det_data_sample.pred_instances
                    self.rec_inputs = []
                    for polygon in det_pred['polygons']:
                        # Roughly convert the polygon to a quadrilateral with
                        # 4 points
                        quad = bbox2poly(poly2bbox(polygon)).tolist()
                        self.rec_inputs.append(crop_img(img, quad))
                    result['rec'].append(
                        self.textrec_inferencer(
                            self.rec_inputs, get_datasample=True))
                if self.mode == 'det_rec_kie':
                    self.kie_inputs = []
                    for img, det_data_sample, rec_data_samples in zip(
                            inputs, result['det'], result['rec']):
                        det_pred = det_data_sample.pred_instances
                        kie_input = dict(img=img)
                        kie_input['instances'] = []
                        for polygon, rec_data_sample in zip(
                                det_pred['polygons'], rec_data_samples):
                            kie_input['instances'].append(
                                dict(
                                    bbox=poly2bbox(polygon),
                                    text=rec_data_sample.pred_text.item))
                        self.kie_inputs.append(kie_input)
                    result['kie'] = self.kie_inferencer(
                        self.kie_inputs, get_datasample=True)
        return result

    def visualize(self, inputs: InputsType, preds: PredType,
                  **kwargs) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
            preds (List[Dict]): Predictions of the model.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            wait_time (float): The interval of show (s). Defaults to 0.
            draw_pred (bool): Whether to draw predicted bounding boxes.
                Defaults to True.
            pred_score_thr (float): Minimum score of bboxes to draw.
                Defaults to 0.3.
            img_out_dir (str): Output directory of images. Defaults to ''.
        """
        if 'kie' in self.mode:
            return self.kie_inferencer.visualize(self.kie_inputs,
                                                 preds['kie'], **kwargs)
        elif 'rec' in self.mode:
            if 'det' in self.mode:
                return super().visualize(inputs,
                                         self._pack_e2e_datasamples(preds),
                                         **kwargs)
            else:
                return self.textrec_inferencer.visualize(
                    self.rec_inputs, preds['rec'][0], **kwargs)
        else:
            return self.textdet_inferencer.visualize(inputs, preds['det'],
                                                     **kwargs)

    def postprocess(self,
                    preds: PredType,
                    imgs: Optional[List[np.ndarray]] = None,
                    is_batch: bool = False,
                    print_result: bool = False,
                    pred_out_file: str = ''
                    ) -> Union[ResType, Tuple[ResType, np.ndarray]]:
        """Postprocess predictions.

        Args:
            preds (Dict): Predictions of the model.
            imgs (Optional[np.ndarray]): Visualized predictions.
            is_batch (bool): Whether the inputs are in a batch.
                Defaults to False.
            print_result (bool): Whether to print the result.
                Defaults to False.
            pred_out_file (str): Output file name to store predictions
                without images. Supported file formats are "json", "yaml/yml"
                and "pickle/pkl". Defaults to ''.

        Returns:
            Dict or List[Dict]: Each dict contains the inference result of
            each image. Possible keys are "det_polygons", "det_scores",
            "rec_texts", "rec_scores", "kie_labels", "kie_scores",
            "kie_edge_labels" and "kie_edge_scores".
        """

        results = [{} for _ in range(len(next(iter(preds.values()))))]
        if 'rec' in self.mode:
            for i, rec_pred in enumerate(preds['rec']):
                result = dict(rec_texts=[], rec_scores=[])
                for rec_pred_instance in rec_pred:
                    pred = rec_pred_instance.pred_text
                    result['rec_texts'].append(pred.item)
                    result['rec_scores'].append(pred.score)
                results[i].update(result)
        if 'det' in self.mode:
            for i, det_pred in enumerate(preds['det']):
                det_pred_instances = det_pred.pred_instances
                results[i].update(
                    dict(
                        det_polygons=det_pred_instances['polygons'],
                        det_scores=det_pred_instances['scores']))
        if 'kie' in self.mode:
            for i, kie_pred in enumerate(preds['kie']):
                kie_pred_instances = kie_pred.pred_instances
                results[i].update(
                    dict(
                        kie_labels=kie_pred_instances['labels'],
                        kie_scores=kie_pred_instances['scores'],
                        kie_edge_scores=kie_pred_instances['edge_scores'],
                        kie_edge_labels=kie_pred_instances['edge_labels']))

        if not is_batch:
            results = results[0]
        if print_result:
            print(results)
        if pred_out_file != '':
            mmcv.dump(results, pred_out_file)
        if imgs is None:
            return results
        return results, imgs

    def _pack_e2e_datasamples(self, preds: Dict) -> List[TextDetDataSample]:
        """Pack text detection and recognition results into a list of
        TextDetDataSample.

        Note that this is a temporary solution since TextSpottingDataSample
        is not ready.
        """
        results = []
        for det_data_sample, rec_data_samples in zip(preds['det'],
                                                     preds['rec']):
            texts = []
            for rec_data_sample in rec_data_samples:
                texts.append(rec_data_sample.pred_text.item)
            det_data_sample.pred_instances.texts = texts
            results.append(det_data_sample)
        return results

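Based on the `postprocess` docstring above, a sketch of what a det+rec result dict could look like for one image (the values are illustrative, not real model output):

```python
# A possible return value of MMOCRInferencer in 'det_rec' mode.
result = {
    'det_polygons': [[159.0, 82.0, 301.0, 82.0, 301.0, 124.0, 159.0, 124.0]],
    'det_scores': [0.95],
    'rec_texts': ['mmocr'],
    'rec_scores': [0.99],
}
```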
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

from mmocr.structures import TextDetDataSample
from .base_mmocr_inferencer import BaseMMOCRInferencer


class TextDetInferencer(BaseMMOCRInferencer):

    def _pred2dict(self, data_sample: TextDetDataSample) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary. It's better to contain only basic data elements such as
        strings and numbers in order to guarantee it's json-serializable.

        Args:
            data_sample (TextDetDataSample): The data sample to be converted.

        Returns:
            dict: The output dictionary.
        """
        result = {}
        pred_instances = data_sample.pred_instances
        result['polygons'] = []
        for polygon in pred_instances.polygons:
            result['polygons'].append(polygon.tolist())
        result['scores'] = pred_instances.scores.cpu().numpy().tolist()
        return result

@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import numpy as np

from mmocr.structures import TextRecogDataSample
from .base_mmocr_inferencer import BaseMMOCRInferencer


class TextRecInferencer(BaseMMOCRInferencer):

    def _pred2dict(self, data_sample: TextRecogDataSample) -> Dict:
        """Extract elements necessary to represent a prediction into a
        dictionary. It's better to contain only basic data elements such as
        strings and numbers in order to guarantee it's json-serializable.

        Args:
            data_sample (TextRecogDataSample): The data sample to be
                converted.

        Returns:
            dict: The output dictionary.
        """
        result = {}
        result['text'] = data_sample.pred_text.item
        result['scores'] = float(np.mean(data_sample.pred_text.score))
        return result

@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .adapters import MMDet2MMOCR, MMOCR2MMDet
 from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
-from .loading import (LoadImageFromFile, LoadImageFromLMDB, LoadKIEAnnotations,
+from .loading import (LoadImageFromFile, LoadImageFromLMDB,
+                      LoadImageFromNDArray, LoadKIEAnnotations,
                       LoadOCRAnnotations)
 from .ocr_transforms import RandomCrop, RandomRotate, Resize
 from .textdet_transforms import (BoundedScaleAspectJitter, FixInvalidPolygon,

@@ -18,5 +19,6 @@ __all__ = [
     'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
     'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
     'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
-    'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile'
+    'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
+    'LoadImageFromNDArray'
 ]

@@ -251,9 +251,7 @@ class PackKIEInputs(BaseTransform):
         'gt_texts': 'texts',
     }

-    def __init__(self,
-                 meta_keys=('img_path', 'ori_shape', 'img_shape',
-                            'scale_factor')):
+    def __init__(self, meta_keys=()):
         self.meta_keys = meta_keys

     def transform(self, results: dict) -> dict:

@@ -114,6 +114,54 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
        return repr_str


@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Load an image from ``results['img']``.

    Similar to :obj:`LoadImageFromFile`, but the image has already been
    loaded as an :obj:`np.ndarray` in ``results['img']``. Can be used when
    loading an image from a webcam.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Transform function to add image meta information.

        Args:
            results (dict): Result dict with a webcam-read image in
                ``results['img']``.

        Returns:
            dict: The dict containing the loaded image and meta information.
        """

        img = results['img']
        if self.to_float32:
            img = img.astype(np.float32)
        if self.color_type == 'grayscale':
            img = mmcv.image.rgb2gray(img)
        results['img_path'] = None
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results


@TRANSFORMS.register_module()
class LoadOCRAnnotations(MMCV_LoadAnnotations):
    """Load and process the ``instances`` annotation provided by dataset.

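A hedged sketch of using the new transform directly, outside a full pipeline (the array stands in for a real webcam frame):

```python
import numpy as np
from mmocr.datasets.transforms import LoadImageFromNDArray

transform = LoadImageFromNDArray()
# e.g. a frame grabbed from a webcam.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
results = transform(dict(img=frame))
print(results['img_shape'], results['ori_shape'])  # (480, 640) (480, 640)
```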
@@ -23,7 +23,7 @@ class SDMGRPostProcessor:
         Defaults to 'none'. Options are:

         - 'none': The simplest link type involving no edge
-          postprocessing. The edge prediction will be returned as it is.
+          postprocessing. The edge prediction will be returned as-is.
         - 'one-to-one': One key node can be connected to one value node.
         - 'one-to-many': One key node can be connected to multiple value
           nodes.

@@ -98,13 +98,13 @@ class SDMGRPostProcessor:

         for i in range(len(data_samples)):
             data_samples[i].pred_instances = InstanceData()
-            data_samples[i].pred_instances.labels = node_preds[i]
-            data_samples[i].pred_instances.scores = node_scores[i]
+            data_samples[i].pred_instances.labels = node_preds[i].cpu()
+            data_samples[i].pred_instances.scores = node_scores[i].cpu()
             if self.link_type != 'none':
                 edge_scores[i], edge_preds[i] = self.decode_edges(
                     node_preds[i], edge_scores[i], edge_preds[i])
-                data_samples[i].pred_instances.edge_labels = edge_preds[i]
-                data_samples[i].pred_instances.edge_scores = edge_scores[i]
+                data_samples[i].pred_instances.edge_labels = edge_preds[i].cpu()
+                data_samples[i].pred_instances.edge_scores = edge_scores[i].cpu()

         return data_samples

@@ -167,4 +167,4 @@ class SDMGRPostProcessor:
         elif self.link_type == 'many-to-one':
             tmp_edge_scores[i, :] = -1

-        return new_edge_scores, new_edge_labels
+        return new_edge_scores.cpu(), new_edge_labels.cpu()

@@ -78,7 +78,7 @@ class BaseTextRecogPostprocessor:
             data_sample (TextRecogDataSample): Datasample of an image.

         Returns:
-            tuple(list[int], list[float]): index and score.
+            tuple(list[int], list[float]): Index and scores per character.
         """
         raise NotImplementedError

@@ -0,0 +1,471 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, List, Optional, Union

import torch

from mmocr.apis.inferencers import MMOCRInferencer
from mmocr.apis.inferencers.base_mmocr_inferencer import InputsType
from mmocr.utils import register_all_modules


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--img-out-dir',
        type=str,
        default='',
        help='Output directory of images.')
    parser.add_argument(
        '--det',
        type=str,
        default=None,
        help='Pretrained text detection algorithm')
    parser.add_argument(
        '--det-config',
        type=str,
        default=None,
        help='Path to the custom config file of the selected det model. It '
        'overrides the settings in det')
    parser.add_argument(
        '--det-ckpt',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected det model. '
        'It overrides the settings in det')
    parser.add_argument(
        '--recog',
        type=str,
        default=None,
        help='Pretrained text recognition algorithm')
    parser.add_argument(
        '--recog-config',
        type=str,
        default=None,
        help='Path to the custom config file of the selected recog model. It '
        'overrides the settings in recog')
    parser.add_argument(
        '--recog-ckpt',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected recog model. '
        'It overrides the settings in recog')
    parser.add_argument(
        '--kie',
        type=str,
        default=None,
        help='Pretrained key information extraction algorithm')
    parser.add_argument(
        '--kie-config',
        type=str,
        default=None,
        help='Path to the custom config file of the selected kie model. It '
        'overrides the settings in kie')
    parser.add_argument(
        '--kie-ckpt',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected kie model. '
        'It overrides the settings in kie')
    parser.add_argument(
        '--config-dir',
        type=str,
        default=os.path.join(str(Path.cwd()), 'configs/'),
        help='Path to the config directory where all the config files '
        'are located. Defaults to "configs/"')
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        help='Device used for inference.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Whether to print the results.')
    parser.add_argument(
        '--pred-out-file',
        type=str,
        default='',
        help='File to save the inference results.')

    args = parser.parse_args()
    # Warnings
    if not os.path.samefile(args.config_dir, str(Path.cwd())) and \
            (args.det_config is not None or args.recog_config is not None):
        warnings.warn(
            'config_dir will be overridden by det-config or recog-config.',
            UserWarning)
    return args


class MMOCR:
    """MMOCR API for text detection, recognition and KIE inference.

    Args:
        det (str): Name of the detection model. Defaults to None.
        det_config (str): Path to the config file for the detection model.
            Defaults to None.
        det_ckpt (str): Path to the checkpoint file for the detection model.
            Defaults to None.
        recog (str): Name of the recognition model. Defaults to None.
        recog_config (str): Path to the config file for the recognition
            model. Defaults to None.
        recog_ckpt (str): Path to the checkpoint file for the recognition
            model. Defaults to None.
        kie (str): Name of the KIE model. Defaults to None.
        kie_config (str): Path to the config file for the KIE model.
            Defaults to None.
        kie_ckpt (str): Path to the checkpoint file for the KIE model.
            Defaults to None.
        config_dir (str): Path to the directory containing config files.
            Defaults to 'configs/'.
        device (str): Device to use for inference. Defaults to 'cuda'.
    """

    def __init__(self,
                 det: str = None,
                 det_config: str = None,
                 det_ckpt: str = None,
                 recog: str = None,
                 recog_config: str = None,
                 recog_ckpt: str = None,
                 kie: str = None,
                 kie_config: str = None,
                 kie_ckpt: str = None,
                 config_dir: str = os.path.join(str(Path.cwd()), 'configs/'),
                 device: str = 'cuda',
                 **kwargs) -> None:

        register_all_modules(init_default_scope=True)

        self.config_dir = config_dir
        inferencer_kwargs = {}
        inferencer_kwargs.update(
            self._get_inferencer_kwargs(det, det_config, det_ckpt, 'det_'))
        inferencer_kwargs.update(
            self._get_inferencer_kwargs(recog, recog_config, recog_ckpt,
                                        'rec_'))
        inferencer_kwargs.update(
            self._get_inferencer_kwargs(kie, kie_config, kie_ckpt, 'kie_'))
        self.inferencer = MMOCRInferencer(device=device, **inferencer_kwargs)

    def _get_inferencer_kwargs(self, model: Optional[str],
                               config: Optional[str], ckpt: Optional[str],
                               prefix: str) -> Dict:
        """Get the kwargs for the inferencer."""
        kwargs = {}

        if model is not None:
            cfgs = self.get_model_config(model)
            kwargs[prefix + 'config'] = os.path.join(self.config_dir,
                                                     cfgs['config'])
            kwargs[prefix + 'ckpt'] = 'https://download.openmmlab.com/' + \
                f'mmocr/{cfgs["ckpt"]}'

        if config is not None:
            if kwargs.get(prefix + 'config', None) is not None:
                warnings.warn(
                    f'{model}\'s default config is overridden by {config}',
                    UserWarning)
            kwargs[prefix + 'config'] = config

        if ckpt is not None:
            if kwargs.get(prefix + 'ckpt', None) is not None:
                warnings.warn(
                    f'{model}\'s default checkpoint is overridden by {ckpt}',
                    UserWarning)
            kwargs[prefix + 'ckpt'] = ckpt
        return kwargs

    def readtext(self,
                 img: InputsType,
                 img_out_dir: str = '',
                 show: bool = False,
                 print_result: bool = False,
                 pred_out_file: str = '',
                 **kwargs) -> Union[Dict, List[Dict]]:
        """Run text detection, recognition and KIE inference on an image or
        a folder of images.

        Args:
            img (str or np.array or Sequence[str or np.array]): Image,
                folder path, np array or list/tuple (with image
                paths or np arrays).
            img_out_dir (str): Output directory of images. Defaults to ''.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            print_result (bool): Whether to print the results.
            pred_out_file (str): File to save the inference results. If left
                empty, no file will be saved.

        Returns:
            Dict or List[Dict]: Each dict contains the inference result of
            each image. Possible keys are "det_polygons", "det_scores",
            "rec_texts", "rec_scores", "kie_labels", "kie_scores",
            "kie_edge_labels" and "kie_edge_scores".
        """
        return self.inferencer(
            img,
            img_out_dir=img_out_dir,
            show=show,
            print_result=print_result,
            pred_out_file=pred_out_file)

    def get_model_config(self, model_name: str) -> Dict:
        """Get the model configuration including model config and checkpoint
        URL.

        Args:
            model_name (str): Name of the model.

        Returns:
            dict: Model configuration.
        """
        model_dict = {
            'Tesseract': {},
            # Detection models
            'DB_r18': {
                'config':
                'textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py',
                'ckpt':
                'textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015/'
                'dbnet_resnet18_fpnc_1200e_icdar2015_20220825_221614-7c0e94f2.pth'  # noqa: E501
            },
            'DB_r50': {
                'config':
                'textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'textdet/dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015/'
                'dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015_20220828_124917-452c443c.pth'  # noqa: E501
            },
            'DBPP_r50': {
                'config':
                'textdet/dbnetpp/dbnetpp_resnet50-dcnv2_fpnc_1200e_icdar2015.py',  # noqa: E501
                'ckpt':
                'textdet/dbnetpp/dbnetpp_resnet50-dcnv2_fpnc_1200e_icdar2015/'
                'dbnetpp_resnet50-dcnv2_fpnc_1200e_icdar2015_20220829_230108-f289bd20.pth'  # noqa: E501
            },
            'DRRG': {
                'config':
                'textdet/drrg/drrg_resnet50_fpn-unet_1200e_ctw1500.py',
                'ckpt':
                'textdet/drrg/drrg_resnet50_fpn-unet_1200e_ctw1500/'
                'drrg_resnet50_fpn-unet_1200e_ctw1500_20220827_105233-d5c702dd.pth'  # noqa: E501
            },
            'FCE_IC15': {
                'config':
                'textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015.py',
                'ckpt':
                'textdet/fcenet/fcenet_resnet50_fpn_1500e_icdar2015/'
                'fcenet_resnet50_fpn_1500e_icdar2015_20220826_140941-167d9042.pth'  # noqa: E501
            },
            'FCE_CTW_DCNv2': {
                'config':
                'textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500.py',
                'ckpt':
                'textdet/fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500/'
                'fcenet_resnet50-dcnv2_fpn_1500e_ctw1500_20220825_221510-4d705392.pth'  # noqa: E501
            },
            'MaskRCNN_CTW': {
                'config':
                'textdet/maskrcnn/mask-rcnn_resnet50_fpn_160e_ctw1500.py',
                'ckpt':
                'textdet/maskrcnn/mask-rcnn_resnet50_fpn_160e_ctw1500/'
                'mask-rcnn_resnet50_fpn_160e_ctw1500_20220826_154755-ce68ee8e.pth'  # noqa: E501
            },
            'MaskRCNN_IC15': {
                'config':
                'textdet/maskrcnn/mask-rcnn_resnet50_fpn_160e_icdar2015.py',
                'ckpt':
                'textdet/maskrcnn/mask-rcnn_resnet50_fpn_160e_icdar2015/'
                'mask-rcnn_resnet50_fpn_160e_icdar2015_20220826_154808-ff5c30bf.pth'  # noqa: E501
            },
            # 'MaskRCNN_IC17': {
            #     'config':
            #     'textdet/maskrcnn/mask-rcnn_resnet50_fpn_160e_icdar2017.py',
            #     'ckpt':
            #     'textdet/maskrcnn/'
            #     ''
            # },
            'PANet_CTW': {
                'config':
                'textdet/panet/panet_resnet18_fpem-ffm_600e_ctw1500.py',
                'ckpt':
                'textdet/panet/panet_resnet18_fpem-ffm_600e_ctw1500/'
                'panet_resnet18_fpem-ffm_600e_ctw1500_20220826_144818-980f32d0.pth'  # noqa: E501
            },
            'PANet_IC15': {
                'config':
                'textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015.py',
                'ckpt':
                'textdet/panet/panet_resnet18_fpem-ffm_600e_icdar2015/'
                'panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth'  # noqa: E501
            },
            'PS_CTW': {
                'config':
                'textdet/psenet/psenet_resnet50_fpnf_600e_ctw1500.py',
                'ckpt':
                'textdet/psenet/psenet_resnet50_fpnf_600e_ctw1500/'
                'psenet_resnet50_fpnf_600e_ctw1500_20220825_221459-7f974ac8.pth'  # noqa: E501
            },
            'PS_IC15': {
                'config':
                'textdet/psenet/psenet_resnet50_fpnf_600e_icdar2015.py',
                'ckpt':
                'textdet/psenet/psenet_resnet50_fpnf_600e_icdar2015/'
                'psenet_resnet50_fpnf_600e_icdar2015_20220825_222709-b6741ec3.pth'  # noqa: E501
            },
            'TextSnake': {
                'config':
                'textdet/textsnake/textsnake_resnet50_fpn-unet_1200e_ctw1500.py',  # noqa: E501
                'ckpt':
                'textdet/textsnake/textsnake_resnet50_fpn-unet_1200e_ctw1500/'
                'textsnake_resnet50_fpn-unet_1200e_ctw1500_20220825_221459-c0b6adc4.pth'  # noqa: E501
            },
            # Recognition models
            'CRNN': {
                'config':
                'textrecog/crnn/crnn_mini-vgg_5e_mj.py',
                'ckpt':
                'textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth'  # noqa: E501
            },
            # 'SAR': {
            #     'config':
            #     'textrecog/sar/'
            #     'sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real.py',
            #     'ckpt': ''
            # },
            # 'SAR_CN': {
            #     'config':
            #     'textrecog/sar/sar_r31_parallel_decoder_chinese.py',
            #     'ckpt': ''
            # },
            # 'NRTR_1/16-1/8': {
            #     'config':
            #     'textrecog/nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py',
            #     'ckpt': ''
            # },
            # 'NRTR_1/8-1/4': {
            #     'config':
            #     'textrecog/nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py',
            #     'ckpt': ''
            # },
            # 'RobustScanner': {
            #     'config':
            #     'textrecog/robust_scanner/'
            #     'robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py',
            #     'ckpt': ''
            # },
            # 'SATRN': {
            #     'config': 'textrecog/satrn/satrn_shallow_5e_st_mj.py',
            #     'ckpt': ''
            # },
            # 'SATRN_sm': {
            #     'config': 'textrecog/satrn/satrn_shallow-small_5e_st_mj.py',
            #     'ckpt': ''
            # },
            # 'ABINet': {
            #     'config': 'textrecog/abinet/abinet_20e_st-an_mj.py',
            #     'ckpt': ''
            # },
            # 'ABINet_Vision': {
            #     'config': 'textrecog/abinet/abinet-vision_20e_st-an_mj.py',
            #     'ckpt': ''
            # },
            # 'CRNN_TPS': {
            #     'config': 'textrecog/tps/crnn_tps_academic_dataset.py',
            #     'ckpt': ''
            # },
            # 'MASTER': {
            #     'config': 'textrecog/master/master_resnet31_12e_st_mj_sa.py',
            #     'ckpt': ''
            # },
            # KIE models
            'SDMGR': {
                'config':
                'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py',
                'ckpt':
                'kie/sdmgr/sdmgr_unet16_60e_wildreceipt/'
                'sdmgr_unet16_60e_wildreceipt_20220825_151648-22419f37.pth'
            }
        }
        if model_name not in model_dict:
            raise ValueError(f'Model {model_name} is not supported.')
        else:
            return model_dict[model_name]


def main():
    args = parse_args()
    ocr = MMOCR(**vars(args))
    ocr.readtext(**vars(args))


if __name__ == '__main__':
    main()

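For reference, a possible use of the `MMOCR` convenience class above (the module path follows the "move ocr.py" note in the commit message and the demo image is a placeholder):

```python
from mmocr.ocr import MMOCR

# Model names come from model_dict above; 'DB_r18' and 'CRNN' have released
# checkpoints, so their configs and weights resolve automatically.
ocr = MMOCR(det='DB_r18', recog='CRNN')
results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True)
```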
@ -1,877 +0,0 @@
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path

import mmcv
import numpy as np
import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config
from PIL import Image

try:
    import tesserocr
except ImportError:
    tesserocr = None

from mmocr.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.datasets import WildReceiptDataset
from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizers import BaseRecognizer
from mmocr.registry import MODELS
from mmocr.utils import is_type_list, stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
from mmocr.utils.img_utils import crop_img
from mmocr.utils.model import revert_sync_batchnorm
from mmocr.visualization.visualize import det_recog_show_result

# Parse CLI arguments
def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        'img', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--output',
        type=str,
        default='',
        help='Output file/folder name for visualization')
    parser.add_argument(
        '--det',
        type=str,
        default='PANet_IC15',
        help='Pretrained text detection algorithm')
    parser.add_argument(
        '--det-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected det model. It '
        'overrides the settings in det')
    parser.add_argument(
        '--det-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected det model. '
        'It overrides the settings in det')
    parser.add_argument(
        '--recog',
        type=str,
        default='SEG',
        help='Pretrained text recognition algorithm')
    parser.add_argument(
        '--recog-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected recog model. It '
        'overrides the settings in recog')
    parser.add_argument(
        '--recog-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected recog model. '
        'It overrides the settings in recog')
    parser.add_argument(
        '--kie',
        type=str,
        default='',
        help='Pretrained key information extraction algorithm')
    parser.add_argument(
        '--kie-config',
        type=str,
        default='',
        help='Path to the custom config file of the selected kie model. It '
        'overrides the settings in kie')
    parser.add_argument(
        '--kie-ckpt',
        type=str,
        default='',
        help='Path to the custom checkpoint file of the selected kie model. '
        'It overrides the settings in kie')
    parser.add_argument(
        '--config-dir',
        type=str,
        default=os.path.join(str(Path.cwd()), 'configs/'),
        help='Path to the config directory where all the config files '
        'are located. Defaults to "configs/"')
    parser.add_argument(
        '--batch-mode',
        action='store_true',
        help='Whether to use batch mode for inference')
    parser.add_argument(
        '--recog-batch-size',
        type=int,
        default=0,
        help='Batch size for text recognition')
    parser.add_argument(
        '--det-batch-size',
        type=int,
        default=0,
        help='Batch size for text detection')
    parser.add_argument(
        '--single-batch-size',
        type=int,
        default=0,
        help='Batch size for separate det/recog inference')
    parser.add_argument(
        '--device', default=None, help='Device used for inference.')
    parser.add_argument(
        '--export',
        type=str,
        default='',
        help='Folder where the results of each image are exported')
    parser.add_argument(
        '--export-format',
        type=str,
        default='json',
        help='Format of the exported result file(s)')
    parser.add_argument(
        '--details',
        action='store_true',
        help='Whether to include the text box coordinates and confidence '
        'values')
    parser.add_argument(
        '--imshow',
        action='store_true',
        help='Whether to show the image with OpenCV.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Prints the recognised text')
    parser.add_argument(
        '--merge', action='store_true', help='Merge neighboring boxes')
    parser.add_argument(
        '--merge-xdist',
        type=float,
        default=20,
        help='The maximum x-axis distance to merge boxes')
    args = parser.parse_args()
    if args.det == 'None':
        args.det = None
    if args.recog == 'None':
        args.recog = None
    # Warnings
    if args.merge and not (args.det and args.recog):
        warnings.warn(
            'Box merging will not work if the script is not'
            ' running in detection + recognition mode.', UserWarning)
    if not os.path.samefile(args.config_dir, os.path.join(str(
            Path.cwd()))) and (args.det_config != ''
                               or args.recog_config != ''):
        warnings.warn(
            'config_dir will be overridden by det-config or recog-config.',
            UserWarning)
    return args


class MMOCR:

    def __init__(self,
                 det='PANet_IC15',
                 det_config='',
                 det_ckpt='',
                 recog='SEG',
                 recog_config='',
                 recog_ckpt='',
                 kie='',
                 kie_config='',
                 kie_ckpt='',
                 config_dir=os.path.join(str(Path.cwd()), 'configs/'),
                 device=None,
                 **kwargs):

        textdet_models = {
            'DB_r18': {
                'config':
                'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
            },
            'DB_r50': {
                'config':
                'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
            },
            'DBPP_r50': {
                'config':
                'dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py',
                'ckpt':
                'dbnet/'
                'dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth'
            },
            'DRRG': {
                'config':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
            },
            'FCE_IC15': {
                'config':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
                'ckpt':
                'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
            },
            'FCE_CTW_DCNv2': {
                'config':
                'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
                'ckpt':
                'fcenet/' +
                'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
            },
            'MaskRCNN_CTW': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
            },
            'MaskRCNN_IC15': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
            },
            'MaskRCNN_IC17': {
                'config':
                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
                'ckpt':
                'maskrcnn/'
                'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
            },
            'PANet_CTW': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
            },
            'PANet_IC15': {
                'config':
                'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
                'ckpt':
                'panet/'
                'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
            },
            'PS_CTW': {
                'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
            },
            'PS_IC15': {
                'config':
                'psenet/psenet_r50_fpnf_600e_icdar2015.py',
                'ckpt':
                'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
            },
            'TextSnake': {
                'config':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
            },
            'Tesseract': {}
        }

        textrecog_models = {
            'CRNN': {
                'config': 'crnn/crnn_academic_dataset.py',
                'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
            },
            'SAR': {
                'config': 'sar/sar_r31_parallel_decoder_academic.py',
                'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
            },
            'SAR_CN': {
                'config':
                'sar/sar_r31_parallel_decoder_chinese.py',
                'ckpt':
                'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
            },
            'NRTR_1/16-1/8': {
                'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
            },
            'NRTR_1/8-1/4': {
                'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
                'ckpt':
                'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
            },
            'RobustScanner': {
                'config': 'robust_scanner/robustscanner_r31_academic.py',
                'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
            },
            'SATRN': {
                'config': 'satrn/satrn_academic.py',
                'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
            },
            'SATRN_sm': {
                'config': 'satrn/satrn_small.py',
                'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
            },
            'ABINet': {
                'config': 'abinet/abinet_academic.py',
                'ckpt': 'abinet/abinet_academic-f718abf6.pth'
            },
            'ABINet_Vision': {
                'config': 'abinet/abinet_vision_only_academic.py',
                'ckpt': 'abinet/abinet_vision_only_academic-e6b9ea89.pth'
            },
            'SEG': {
                'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
                'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
            },
            'CRNN_TPS': {
                'config': 'tps/crnn_tps_academic_dataset.py',
                'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
            },
            'Tesseract': {},
            'MASTER': {
                'config': 'master/master_r31_12e_ST_MJ_SA.py',
                'ckpt': 'master/master_r31_12e_ST_MJ_SA-787edd36.pth'
            }
        }

        kie_models = {
            'SDMGR': {
                'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
                'ckpt':
                'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
            }
        }

        self.td = det
        self.tr = recog
        self.kie = kie
        self.device = device
        if self.device is None:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        # Check if the det/recog model choice is valid
        if self.td and self.td not in textdet_models:
            raise ValueError(self.td,
                             'is not a supported text detection algorithm')
        elif self.tr and self.tr not in textrecog_models:
            raise ValueError(self.tr,
                             'is not a supported text recognition algorithm')
        elif self.kie:
            if self.kie not in kie_models:
                raise ValueError(
                    self.kie, 'is not a supported key information extraction'
                    ' algorithm')
            elif not (self.td and self.tr):
                raise NotImplementedError(
                    self.kie, 'has to run together'
                    ' with text detection and recognition algorithms.')

        self.detect_model = None
        if self.td and self.td == 'Tesseract':
            if tesserocr is None:
                raise ImportError('Please install tesserocr first. '
                                  'Check out the installation guide at '
                                  'https://github.com/sirfz/tesserocr')
            self.detect_model = 'Tesseract_det'
        elif self.td:
            # Build detection model
            if not det_config:
                det_config = os.path.join(config_dir, 'textdet/',
                                          textdet_models[self.td]['config'])
            if not det_ckpt:
                det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
                    textdet_models[self.td]['ckpt']

            self.detect_model = init_detector(
                det_config, det_ckpt, device=self.device)
            self.detect_model = revert_sync_batchnorm(self.detect_model)

        self.recog_model = None
        if self.tr and self.tr == 'Tesseract':
            if tesserocr is None:
                raise ImportError('Please install tesserocr first. '
                                  'Check out the installation guide at '
                                  'https://github.com/sirfz/tesserocr')
            self.recog_model = 'Tesseract_recog'
        elif self.tr:
            # Build recognition model
            if not recog_config:
                recog_config = os.path.join(
                    config_dir, 'textrecog/',
                    textrecog_models[self.tr]['config'])
            if not recog_ckpt:
                recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'textrecog/' + textrecog_models[self.tr]['ckpt']

            self.recog_model = init_detector(
                recog_config, recog_ckpt, device=self.device)
            self.recog_model = revert_sync_batchnorm(self.recog_model)

        self.kie_model = None
        if self.kie:
            # Build key information extraction model
            if not kie_config:
                kie_config = os.path.join(config_dir, 'kie/',
                                          kie_models[self.kie]['config'])
            if not kie_ckpt:
                kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
                    'kie/' + kie_models[self.kie]['ckpt']

            kie_cfg = Config.fromfile(kie_config)
            self.kie_model = MODELS.build(
                kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
            self.kie_model = revert_sync_batchnorm(self.kie_model)
            self.kie_model.cfg = kie_cfg
            load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)

        # Attribute check
        for model in list(filter(None, [self.recog_model, self.detect_model])):
            if hasattr(model, 'module'):
                model = model.module

    @staticmethod
    def get_tesserocr_api():
        """Get a tesserocr api handle appropriate for the current platform."""
        import subprocess
        import sys

        if sys.platform == 'linux':
            api = tesserocr.PyTessBaseAPI()
        elif sys.platform == 'win32':
            try:
                p = subprocess.Popen(
                    'where tesseract', stdout=subprocess.PIPE, shell=True)
                s = p.communicate()[0].decode('utf-8').split('\\')
                path = s[:-1] + ['tessdata']
                tessdata_path = '/'.join(path)
                api = tesserocr.PyTessBaseAPI(path=tessdata_path)
            except RuntimeError:
                raise RuntimeError(
                    'Please install tesseract first.\n Check out the'
                    ' installation guide at'
                    ' https://github.com/UB-Mannheim/tesseract/wiki')
        else:
            raise NotImplementedError
        return api

    def tesseract_det_inference(self, imgs, **kwargs):
        """Inference image(s) with the tesseract detector.

        Args:
            imgs (ndarray or list[ndarray]): Image(s) to run inference on.

        Returns:
            result (dict or list[dict]): Predicted results.
        """
        is_batch = True
        if isinstance(imgs, np.ndarray):
            is_batch = False
            imgs = [imgs]
        assert is_type_list(imgs, np.ndarray)
        api = self.get_tesserocr_api()

        # Get detection result using tesseract
        results = []
        for img in imgs:
            image = Image.fromarray(img)
            api.SetImage(image)
            boxes = api.GetComponentImages(tesserocr.RIL.TEXTLINE, True)
            boundaries = []
            for _, box, _, _ in boxes:
                min_x = box['x']
                min_y = box['y']
                max_x = box['x'] + box['w']
                max_y = box['y'] + box['h']
                boundary = [
                    min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0
                ]
                boundaries.append(boundary)
            results.append({'boundary_result': boundaries})

        # Close the tesserocr api
        api.End()

        if not is_batch:
            return results[0]
        else:
            return results
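The per-box conversion above reduces to a small, self-contained transform; the sample numbers below are invented:

# A tesserocr component box {'x', 'y', 'w', 'h'} becomes an 8-point
# clockwise polygon plus a fixed 1.0 confidence score.
box = {'x': 10, 'y': 20, 'w': 100, 'h': 30}
min_x, min_y = box['x'], box['y']
max_x, max_y = box['x'] + box['w'], box['y'] + box['h']
boundary = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0]
assert boundary == [10, 20, 110, 20, 110, 50, 10, 50, 1.0]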

    def tesseract_recog_inference(self, imgs, **kwargs):
        """Inference image(s) with the tesseract recognizer.

        Args:
            imgs (ndarray or list[ndarray]): Image(s) to run inference on.

        Returns:
            result (dict or list[dict]): Predicted results.
        """
        is_batch = True
        if isinstance(imgs, np.ndarray):
            is_batch = False
            imgs = [imgs]
        assert is_type_list(imgs, np.ndarray)
        api = self.get_tesserocr_api()

        results = []
        for img in imgs:
            image = Image.fromarray(img)
            api.SetImage(image)
            api.SetRectangle(0, 0, img.shape[1], img.shape[0])
            # Strip leading and trailing spaces from the Tesseract output
            text = api.GetUTF8Text().strip()
            conf = api.MeanTextConf() / 100
            results.append({'text': text, 'score': conf})

        # Close the tesserocr api
        api.End()

        if not is_batch:
            return results[0]
        else:
            return results

    def readtext(self,
                 img,
                 output=None,
                 details=False,
                 export=None,
                 export_format='json',
                 batch_mode=False,
                 recog_batch_size=0,
                 det_batch_size=0,
                 single_batch_size=0,
                 imshow=False,
                 print_result=False,
                 merge=False,
                 merge_xdist=20,
                 **kwargs):
        args = locals().copy()
        [args.pop(x, None) for x in ['kwargs', 'self']]
        args = Namespace(**args)

        # Input and output arguments processing
        self._args_processing(args)
        self.args = args

        pp_result = None

        # Send args and models to the MMOCR model inference API
        # and call post-processing functions for the output
        if self.detect_model and self.recog_model:
            det_recog_result = self.det_recog_kie_inference(
                self.detect_model, self.recog_model, kie_model=self.kie_model)
            pp_result = self.det_recog_pp(det_recog_result)
        else:
            for model in list(
                    filter(None, [self.recog_model, self.detect_model])):
                result = self.single_inference(model, args.arrays,
                                               args.batch_mode,
                                               args.single_batch_size)
                pp_result = self.single_pp(result, model)

        return pp_result
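A minimal sketch of how this legacy readtext entry point was typically called; PANet_IC15 and SEG are the constructor defaults shown above, and the image path is a placeholder:

# Legacy end-to-end OCR call with the default det/recog pair.
mmocr = MMOCR()
results = mmocr.readtext('demo/demo_text_det.jpg', print_result=True)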

    # Post processing function for end2end ocr
    def det_recog_pp(self, result):
        final_results = []
        args = self.args
        for arr, output, export, det_recog_result in zip(
                args.arrays, args.output, args.export, result):
            if output or args.imshow:
                if self.kie_model:
                    res_img = det_recog_show_result(arr, det_recog_result)
                else:
                    res_img = det_recog_show_result(
                        arr, det_recog_result, out_file=output)
                if args.imshow and not self.kie_model:
                    mmcv.imshow(res_img, 'inference results')
            if not args.details:
                simple_res = {}
                simple_res['filename'] = det_recog_result['filename']
                simple_res['text'] = [
                    x['text'] for x in det_recog_result['result']
                ]
                final_result = simple_res
            else:
                final_result = det_recog_result
            if export:
                mmcv.dump(final_result, export, indent=4)
            if args.print_result:
                print(final_result, end='\n\n')
            final_results.append(final_result)
        return final_results

    # Post processing function for separate det/recog inference
    def single_pp(self, result, model):
        for arr, output, export, res in zip(self.args.arrays, self.args.output,
                                            self.args.export, result):
            if export:
                mmcv.dump(res, export, indent=4)
            if output or self.args.imshow:
                if model == 'Tesseract_det':
                    res_img = TextDetectorMixin(show_score=False).show_result(
                        arr, res, out_file=output)
                elif model == 'Tesseract_recog':
                    res_img = BaseRecognizer.show_result(
                        arr, res, out_file=output)
                else:
                    res_img = model.show_result(arr, res, out_file=output)
                if self.args.imshow:
                    mmcv.imshow(res_img, 'inference results')
            if self.args.print_result:
                print(res, end='\n\n')
        return result

    def generate_kie_labels(self, result, boxes, class_list):
        idx_to_cls = {}
        if class_list is not None:
            for line in list_from_file(class_list):
                class_idx, class_label = line.strip().split()
                idx_to_cls[class_idx] = class_label

        max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
        node_pred_label = max_idx.numpy().tolist()
        node_pred_score = max_value.numpy().tolist()
        labels = []
        for i in range(len(boxes)):
            pred_label = str(node_pred_label[i])
            if pred_label in idx_to_cls:
                pred_label = idx_to_cls[pred_label]
            pred_score = node_pred_score[i]
            labels.append((pred_label, pred_score))
        return labels
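In isolation, the label extraction above is an argmax over per-node class scores; a runnable toy example with invented scores:

import torch

# Two detected boxes, three candidate classes.
nodes = torch.tensor([[0.1, 0.7, 0.2],
                      [0.8, 0.1, 0.1]])
max_value, max_idx = torch.max(nodes, -1)
# max_idx   -> tensor([1, 0]): predicted class index per box
# max_value -> tensor([0.7000, 0.8000]): its confidence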

    def visualize_kie_output(self,
                             model,
                             data,
                             result,
                             out_file=None,
                             show=False):
        """Visualizes KIE output."""
        img_tensor = data['img'].data
        img_meta = data['img_metas'].data
        gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
        if img_tensor.dtype == torch.uint8:
            # The img tensor is the raw input that has not been normalized
            # (for SDMGR non-visual)
            img = img_tensor.cpu().numpy().transpose(1, 2, 0)
        else:
            img = tensor2imgs(
                img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0]
        h, w, _ = img_meta.get('img_shape', img.shape)
        img_show = img[:h, :w, :]
        model.show_result(
            img_show, result, gt_bboxes, show=show, out_file=out_file)

    # End2end ocr inference pipeline
    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
        end2end_res = []
        # Find bounding boxes in the images (text detection)
        det_result = self.single_inference(det_model, self.args.arrays,
                                           self.args.batch_mode,
                                           self.args.det_batch_size)
        bboxes_list = [res['boundary_result'] for res in det_result]

        if kie_model:
            kie_dataset = WildReceiptDataset(
                dict_file=kie_model.cfg.data.test.dict_file)

        # For each bounding box, the image is cropped and
        # sent to the recognition model either one by one
        # or all together depending on the batch_mode
        for filename, arr, bboxes, out_file in zip(self.args.filenames,
                                                   self.args.arrays,
                                                   bboxes_list,
                                                   self.args.output):
            img_e2e_res = {}
            img_e2e_res['filename'] = filename
            img_e2e_res['result'] = []
            box_imgs = []
            for bbox in bboxes:
                box_res = {}
                box_res['box'] = [round(x) for x in bbox[:-1]]
                box_res['box_score'] = float(bbox[-1])
                box = bbox[:8]
                if len(bbox) > 9:
                    min_x = min(bbox[0:-1:2])
                    min_y = min(bbox[1:-1:2])
                    max_x = max(bbox[0:-1:2])
                    max_y = max(bbox[1:-1:2])
                    box = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                box_img = crop_img(arr, box)
                if self.args.batch_mode:
                    box_imgs.append(box_img)
                else:
                    if recog_model == 'Tesseract_recog':
                        recog_result = self.single_inference(
                            recog_model, box_img, batch_mode=True)
                    else:
                        recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, list):
                        text_score = sum(text_score) / max(1, len(text))
                    box_res['text'] = text
                    box_res['text_score'] = text_score
                img_e2e_res['result'].append(box_res)

            if self.args.batch_mode:
                recog_results = self.single_inference(
                    recog_model, box_imgs, True, self.args.recog_batch_size)
                for i, recog_result in enumerate(recog_results):
                    text = recog_result['text']
                    text_score = recog_result['score']
                    if isinstance(text_score, (list, tuple)):
                        text_score = sum(text_score) / max(1, len(text))
                    img_e2e_res['result'][i]['text'] = text
                    img_e2e_res['result'][i]['text_score'] = text_score

            if self.args.merge:
                img_e2e_res['result'] = stitch_boxes_into_lines(
                    img_e2e_res['result'], self.args.merge_xdist, 0.5)

            if kie_model:
                annotations = copy.deepcopy(img_e2e_res['result'])
                # Customized for kie_dataset, which
                # assumes that boxes are represented by only 4 points
                for i, ann in enumerate(annotations):
                    min_x = min(ann['box'][::2])
                    min_y = min(ann['box'][1::2])
                    max_x = max(ann['box'][::2])
                    max_y = max(ann['box'][1::2])
                    annotations[i]['box'] = [
                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
                    ]
                ann_info = kie_dataset._parse_anno_info(annotations)
                ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
                                                      ann_info['bboxes'])
                ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
                                                     ann_info['bboxes'])
                kie_result, data = model_inference(
                    kie_model,
                    arr,
                    ann=ann_info,
                    return_data=True,
                    batch_mode=self.args.batch_mode)
                # visualize KIE results
                self.visualize_kie_output(
                    kie_model,
                    data,
                    kie_result,
                    out_file=out_file,
                    show=self.args.imshow)
                gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
                labels = self.generate_kie_labels(kie_result, gt_bboxes,
                                                  kie_model.class_list)
                for i in range(len(gt_bboxes)):
                    img_e2e_res['result'][i]['label'] = labels[i][0]
                    img_e2e_res['result'][i]['label_score'] = labels[i][1]

            end2end_res.append(img_e2e_res)
        return end2end_res

    # Separate det/recog inference pipeline
    def single_inference(self, model, arrays, batch_mode, batch_size=0):

        def inference(m, a, **kwargs):
            if model == 'Tesseract_det':
                return self.tesseract_det_inference(a)
            elif model == 'Tesseract_recog':
                return self.tesseract_recog_inference(a)
            else:
                return model_inference(m, a, **kwargs)

        result = []
        if batch_mode:
            if batch_size == 0:
                result = inference(model, arrays, batch_mode=True)
            else:
                n = batch_size
                arr_chunks = [
                    arrays[i:i + n] for i in range(0, len(arrays), n)
                ]
                for chunk in arr_chunks:
                    result.extend(inference(model, chunk, batch_mode=True))
        else:
            for arr in arrays:
                result.append(inference(model, arr, batch_mode=False))
        return result
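The batch_size branch above is a standard fixed-size chunking pattern; shown standalone:

# Split inputs into consecutive chunks of at most n items, mirroring
# the arr_chunks computation in single_inference.
def make_chunks(items, n):
    return [items[i:i + n] for i in range(0, len(items), n)]

assert make_chunks([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]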

    # Arguments pre-processing function
    def _args_processing(self, args):
        # Check if the input is a list/tuple that
        # contains only np arrays or strings
        if isinstance(args.img, (list, tuple)):
            img_list = args.img
            if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
                raise AssertionError('Images must be strings or numpy arrays')

        # Create a list of the images
        if isinstance(args.img, str):
            img_path = Path(args.img)
            if img_path.is_dir():
                img_list = [str(x) for x in img_path.glob('*')]
            else:
                img_list = [str(img_path)]
        elif isinstance(args.img, np.ndarray):
            img_list = [args.img]

        # Read all image(s) in advance to reduce wasted time
        # re-reading the images for visualization output
        args.arrays = [mmcv.imread(x) for x in img_list]

        # Create a list of filenames (used for output images and result files)
        if isinstance(img_list[0], str):
            args.filenames = [str(Path(x).stem) for x in img_list]
        else:
            args.filenames = [str(x) for x in range(len(img_list))]

        # If given an output argument, create a list of output image filenames
        num_res = len(img_list)
        if args.output:
            output_path = Path(args.output)
            if output_path.is_dir():
                args.output = [
                    str(output_path / f'out_{x}.png') for x in args.filenames
                ]
            else:
                args.output = [str(args.output)]
                if args.batch_mode:
                    raise AssertionError('Output of multiple images inference'
                                         ' must be a directory')
        else:
            args.output = [None] * num_res

        # If given an export argument, create a list of
        # result filenames for each image
        if args.export:
            export_path = Path(args.export)
            args.export = [
                str(export_path / f'out_{x}.{args.export_format}')
                for x in args.filenames
            ]
        else:
            args.export = [None] * num_res

        return args
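The normalization above accepts a path, a directory, an ndarray, or a list of either; a condensed restatement of just that branch logic (stdlib and numpy only):

from pathlib import Path

import numpy as np

def to_img_list(img):
    # A str may name a file or a directory; an ndarray is wrapped in a
    # list; lists/tuples are validated element-wise.
    if isinstance(img, (list, tuple)):
        assert all(isinstance(x, (np.ndarray, str)) for x in img)
        return list(img)
    if isinstance(img, str):
        path = Path(img)
        if path.is_dir():
            return [str(x) for x in path.glob('*')]
        return [str(path)]
    if isinstance(img, np.ndarray):
        return [img]
    raise TypeError('Images must be strings or numpy arrays')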


# Create an inference pipeline with parsed arguments
def main():
    args = parse_args()
    ocr = MMOCR(**vars(args))
    ocr.readtext(**vars(args))


if __name__ == '__main__':
    main()

@ -16,6 +16,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
        to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
            Defaults to True.
    """  # noqa
    import mmocr.apis  # noqa: F401,F403
    import mmocr.datasets  # noqa: F401,F403
    import mmocr.engine  # noqa: F401,F403
    import mmocr.evaluation  # noqa: F401,F403

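Downstream scripts call this hook once before building models; presumably along these lines (the import location is an assumption, as the hunk does not show it):

# Assumed import path for the function whose docstring is patched above.
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=True)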
@ -174,8 +174,6 @@ class TestPackKIEInputs(TestCase):
                         torch.int64)
        self.assertIsInstance(data_sample.gt_instances.texts, list)

        self.assertIn('img_path', data_sample)

        transform = PackKIEInputs(meta_keys=('img_path', ))
        results = transform(copy.deepcopy(datainfo))
        self.assertIn('inputs', results)

@ -191,7 +189,4 @@ class TestPackKIEInputs(TestCase):
        self.assertEqual(results['inputs'].shape, torch.Size((0, 0, 0)))

    def test_repr(self):
        self.assertEqual(
            repr(self.transform),
            ("PackKIEInputs(meta_keys=('img_path', 'ori_shape', "
             "'img_shape', 'scale_factor'))"))
        self.assertEqual(repr(self.transform), ('PackKIEInputs(meta_keys=())'))