[Refactor] ocr.py (#1344)

* [Feature] Add BaseInferencer

* [Feature] Add Det&Rec Inferencer

* [Feature] Add KIEInferencer

* [Feature] Add MMOCRInferencer

* [Refactor] update ocr.py

* update links

* update two links

* remove ocr.py

* move ocr.py and add LoadImageFromNDArray

Co-authored-by: xinyu <wangxinyu2017@gmail.com>
Co-authored-by: liukuikun <liukuikun@sensetime.com>
pull/1362/head
Tong Gao 2022-08-31 22:56:24 +08:00 committed by GitHub
parent dbb346afed
commit db6ce0d95e
17 changed files with 1487 additions and 895 deletions

View File

@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inferencers import * # NOQA

View File

@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .kie_inferencer import KIEInferencer
from .mmocr_inferencer import MMOCRInferencer
from .textdet_inferencer import TextDetInferencer
from .textrec_inferencer import TextRecInferencer
__all__ = [
'TextDetInferencer', 'TextRecInferencer', 'KIEInferencer',
'MMOCRInferencer'
]

View File

@ -0,0 +1,195 @@
# Copyright (c) OpenMMLab. All rights reserved.
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from mmengine.config import Config
from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData
from mmocr.registry import MODELS, VISUALIZERS
from mmocr.utils import ConfigType
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict]]
class BaseInferencer:
"""Base inferencer.
Args:
model (str or ConfigType): Model config or the path to it.
ckpt (str, optional): Path to the checkpoint.
device (str, optional): Device to run inference. If None, the best
device will be automatically used.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
        pred_out_file (str): File to save the inference results. If left
            empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the result.
Defaults to False.
"""
func_kwargs = dict(preprocess=[], forward=[], visualize=[], postprocess=[])
func_order = dict(preprocess=0, forward=1, visualize=2, postprocess=3)
def __init__(self,
config: Union[ConfigType, str],
ckpt: Optional[str],
device: Optional[str] = None,
**kwargs) -> None:
# Load config to cfg
        if isinstance(config, str):
            cfg = Config.fromfile(config)
        elif isinstance(config, (Config, dict)):
            cfg = config
        else:
            raise TypeError('config must be a filename or any ConfigType '
                            f'object, but got {type(config)}')
if cfg.model.get('pretrained'):
cfg.model.pretrained = None
if device is None:
device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self._init_model(cfg, ckpt, device)
self._init_pipeline(cfg)
self._init_visualizer(cfg)
self.base_params = self._dispatch_kwargs(**kwargs)
def _init_model(self, cfg: Union[ConfigType, str], ckpt: Optional[str],
device: str) -> None:
"""Initialize the model with the given config and checkpoint on the
specific device."""
model = MODELS.build(cfg.model)
if ckpt is not None:
ckpt = load_checkpoint(model, ckpt, map_location='cpu')
model.cfg = cfg.model
model.to(device)
model.eval()
self.model = model
def _init_pipeline(self, cfg: ConfigType) -> None:
"""Initialize the test pipeline."""
raise NotImplementedError
def _init_visualizer(self, cfg: ConfigType) -> None:
"""Initialize visualizers."""
# TODO: We don't export images via backends since the interface
# of the visualizer will have to be refactored.
self.visualizer = None
if 'visualizer' in cfg:
ts = str(datetime.timestamp(datetime.now()))
cfg.visualizer['name'] = f'inferencer{ts}'
self.visualizer = VISUALIZERS.build(cfg.visualizer)
def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
"""Dispatch kwargs to preprocess(), forward(), visualize() and
postprocess() according to the actual demands."""
results = [{}, {}, {}, {}]
dispatched_kwargs = set()
# Dispatch kwargs according to self.func_kwargs
for func_name, func_kwargs in self.func_kwargs.items():
for func_kwarg in func_kwargs:
if func_kwarg in kwargs:
dispatched_kwargs.add(func_kwarg)
results[self.func_order[func_name]][func_kwarg] = kwargs[
func_kwarg]
# Find if there is any kwargs that are not dispatched
for kwarg in kwargs:
if kwarg not in dispatched_kwargs:
raise ValueError(f'Unknown kwarg: {kwarg}')
return results
def __call__(self, inputs: InputsType,
**kwargs) -> Union[Dict, List[Dict]]:
"""Call the inferencer.
Args:
user_inputs: Inputs for the inferencer.
kwargs: Keyword arguments for the inferencer.
"""
params = self._dispatch_kwargs(**kwargs)
preprocess_kwargs = self.base_params[0].copy()
preprocess_kwargs.update(params[0])
forward_kwargs = self.base_params[1].copy()
forward_kwargs.update(params[1])
visualize_kwargs = self.base_params[2].copy()
visualize_kwargs.update(params[2])
postprocess_kwargs = self.base_params[3].copy()
postprocess_kwargs.update(params[3])
data = self.preprocess(inputs, **preprocess_kwargs)
preds = self.forward(data, **forward_kwargs)
imgs = self.visualize(inputs, preds, **visualize_kwargs)
results = self.postprocess(preds, imgs, **postprocess_kwargs)
return results
def preprocess(self, inputs: InputsType) -> List[Dict]:
"""Process the inputs into a model-feedable format."""
raise NotImplementedError
def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model."""
with torch.no_grad():
return self.model.test_step(inputs)
def visualize(self,
inputs: InputsType,
preds: PredType,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> List[np.ndarray]:
"""Visualize predictions.
Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
"""
raise NotImplementedError
def postprocess(
self,
preds: PredType,
imgs: Optional[List[np.ndarray]] = None,
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Postprocess predictions.
Args:
preds (List[Dict]): Predictions of the model.
imgs (Optional[np.ndarray]): Visualized predictions.
is_batch (bool): Whether the inputs are in a batch.
Defaults to False.
print_result (bool): Whether to print the result.
Defaults to False.
pred_out_file (str): Output file name to store predictions
without images. Supported file formats are json, yaml/yml
and pickle/pkl. Defaults to ''.
Returns:
TODO
"""
raise NotImplementedError
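
A minimal, self-contained sketch of the kwargs dispatch above: the toy class below is hypothetical and independent of mmocr, and only mirrors how func_kwargs and func_order route keyword arguments to the four stages.

class _ToyInferencer:
    """Hypothetical class mirroring BaseInferencer's kwargs dispatch."""
    func_kwargs = dict(
        preprocess=[], forward=[], visualize=['show'],
        postprocess=['print_result'])
    func_order = dict(preprocess=0, forward=1, visualize=2, postprocess=3)

    def _dispatch_kwargs(self, **kwargs):
        # One dict per stage, ordered by func_order
        results = [{}, {}, {}, {}]
        dispatched = set()
        for func_name, arg_names in self.func_kwargs.items():
            for name in arg_names:
                if name in kwargs:
                    dispatched.add(name)
                    results[self.func_order[func_name]][name] = kwargs[name]
        for name in kwargs:
            if name not in dispatched:
                raise ValueError(f'Unknown kwarg: {name}')
        return results


# 'show' is routed to visualize(), 'print_result' to postprocess()
print(_ToyInferencer()._dispatch_kwargs(show=True, print_result=True))
# [{}, {}, {'show': True}, {'print_result': True}]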

View File

@ -0,0 +1,267 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
from mmengine.dataset import Compose
from mmengine.structures import InstanceData
from mmocr.utils import ConfigType
from .base_inferencer import BaseInferencer
InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
class BaseMMOCRInferencer(BaseInferencer):
"""Base inferencer.
Args:
model (str or ConfigType): Model config or the path to it.
ckpt (str, optional): Path to the checkpoint.
device (str, optional): Device to run inference. If None, the best
device will be automatically used.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
        pred_out_file (str): File to save the inference results. If left
            empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the result.
Defaults to False.
"""
func_kwargs = dict(
preprocess=[],
forward=[],
visualize=[
'show', 'wait_time', 'draw_pred', 'pred_score_thr', 'img_out_dir'
],
postprocess=['print_result', 'pred_out_file', 'get_datasample'])
def __init__(self,
config: Union[ConfigType, str],
ckpt: Optional[str],
device: Optional[str] = None,
**kwargs) -> None:
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
super().__init__(config=config, ckpt=ckpt, device=device, **kwargs)
def _init_pipeline(self, cfg: ConfigType) -> None:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
# For inference, the key of ``instances`` is not used.
if 'meta_keys' in pipeline_cfg[-1]:
pipeline_cfg[-1]['meta_keys'] = tuple(
meta_key for meta_key in pipeline_cfg[-1]['meta_keys']
if meta_key != 'instances')
# Loading annotations is also not applicable
idx = self._get_transform_idx(pipeline_cfg, 'LoadOCRAnnotations')
if idx != -1:
del pipeline_cfg[idx]
self.file_pipeline = Compose(pipeline_cfg)
load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile')
if load_img_idx == -1:
raise ValueError(
'LoadImageFromFile is not found in the test pipeline')
pipeline_cfg[load_img_idx]['type'] = 'LoadImageFromNDArray'
self.ndarray_pipeline = Compose(pipeline_cfg)
def _get_transform_idx(self, pipeline_cfg: ConfigType, name: str) -> int:
"""Returns the index of the transform in a pipeline.
If the transform is not found, returns -1.
"""
for i, transform in enumerate(pipeline_cfg):
if transform['type'] == name:
return i
return -1
def preprocess(self, inputs: InputsType) -> Dict:
"""Process the inputs into a model-feedable format."""
results = []
for single_input in inputs:
if isinstance(single_input, str):
if osp.isdir(single_input):
raise ValueError('Feeding a directory is not supported')
# for img_path in os.listdir(single_input):
# data_ =dict(img_path=osp.join(single_input,img_path))
# results.append(self.file_pipeline(data_))
else:
data_ = dict(img_path=single_input)
results.append(self.file_pipeline(data_))
elif isinstance(single_input, np.ndarray):
data_ = dict(img=single_input)
results.append(self.ndarray_pipeline(data_))
else:
raise ValueError(
f'Unsupported input type: {type(single_input)}')
return self._collate(results)
def _collate(self, results: List[Dict]) -> Dict:
"""Collate the results from different images."""
results = {key: [d[key] for d in results] for key in results[0]}
return results
def __call__(self, user_inputs: InputsType,
**kwargs) -> Union[Dict, List[Dict]]:
"""Call the inferencer.
Args:
user_inputs: Inputs for the inferencer.
kwargs: Keyword arguments for the inferencer.
"""
# Detect if user_inputs are in a batch
is_batch = isinstance(user_inputs, (list, tuple))
inputs = user_inputs if is_batch else [user_inputs]
params = self._dispatch_kwargs(**kwargs)
preprocess_kwargs = self.base_params[0].copy()
preprocess_kwargs.update(params[0])
forward_kwargs = self.base_params[1].copy()
forward_kwargs.update(params[1])
visualize_kwargs = self.base_params[2].copy()
visualize_kwargs.update(params[2])
postprocess_kwargs = self.base_params[3].copy()
postprocess_kwargs.update(params[3])
data = self.preprocess(inputs, **preprocess_kwargs)
preds = self.forward(data, **forward_kwargs)
imgs = self.visualize(inputs, preds, **visualize_kwargs)
results = self.postprocess(
preds, imgs, is_batch=is_batch, **postprocess_kwargs)
return results
def visualize(self,
inputs: InputsType,
preds: PredType,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> List[np.ndarray]:
"""Visualize predictions.
Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
"""
        if not show and img_out_dir == '':
            return None
        if self.visualizer is None:
            raise ValueError('Visualization requires the "visualizer" field '
                             'to be defined in the config, but it is None.')
results = []
for single_input, pred in zip(inputs, preds):
if isinstance(single_input, str):
img = mmcv.imread(single_input)
img = img[:, :, ::-1]
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type: '
f'{type(single_input)}')
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None
self.visualizer.add_datasample(
img_name,
img,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
out_file=out_file,
)
results.append(img)
self.num_visualized_imgs += 1
return results
def postprocess(
self,
preds: PredType,
imgs: Optional[List[np.ndarray]] = None,
is_batch: bool = False,
print_result: bool = False,
pred_out_file: str = '',
get_datasample: bool = False,
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Postprocess predictions.
Args:
preds (List[Dict]): Predictions of the model.
imgs (Optional[np.ndarray]): Visualized predictions.
is_batch (bool): Whether the inputs are in a batch.
Defaults to False.
print_result (bool): Whether to print the result.
Defaults to False.
pred_out_file (str): Output file name to store predictions
without images. Supported file formats are json, yaml/yml
and pickle/pkl. Defaults to ''.
get_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used.
Returns:
TODO
"""
results = preds
if not get_datasample:
results = []
for pred in preds:
result = self._pred2dict(pred)
results.append(result)
if not is_batch:
results = results[0]
if print_result:
print(results)
# Add img to the results after printing
if pred_out_file != '':
mmcv.dump(results, pred_out_file)
if imgs is None:
return results
return results, imgs
def _pred2dict(self, data_sample: InstanceData) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary.
It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable.
"""
raise NotImplementedError
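
A rough usage sketch of a concrete inferencer built on this base, assuming the DBNet R18 config and checkpoint listed in ocr.py's model table below; the demo image path is a placeholder. Both file paths and ndarrays are accepted, and a list/tuple is treated as a batch.

import numpy as np
from mmocr.apis.inferencers import TextDetInferencer
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=True)
ckpt = ('https://download.openmmlab.com/mmocr/textdet/dbnet/'
        'dbnet_resnet18_fpnc_1200e_icdar2015/'
        'dbnet_resnet18_fpnc_1200e_icdar2015_20220825_221614-7c0e94f2.pth')
inferencer = TextDetInferencer(
    config='configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py',
    ckpt=ckpt,
    device='cpu')

# A single path goes through file_pipeline; kwargs are dispatched per stage.
result = inferencer('demo/demo_text_det.jpg', pred_out_file='det.json')

# ndarray inputs go through ndarray_pipeline (LoadImageFromNDArray); a
# list/tuple of inputs is treated as a batch.
img = np.zeros((224, 224, 3), dtype=np.uint8)
batch_results = inferencer([img, img])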

View File

@ -0,0 +1,194 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Sequence
import mmcv
import numpy as np
from mmengine.dataset import Compose
from mmocr.registry import DATASETS, VISUALIZERS
from mmocr.structures import KIEDataSample
from mmocr.utils import ConfigType
from .base_mmocr_inferencer import BaseMMOCRInferencer, PredType
InputType = Dict
InputsType = Sequence[Dict]
class KIEInferencer(BaseMMOCRInferencer):
"""
Inputs:
dict or list[dict]: A dictionary containing the following keys:
'bbox', 'texts', ` in this format:
- img (str or ndarray): Path to the image or the image itself.
- img_shape (tuple(int, int)): Image shape in (H, W). In
- instances (list[dict]): A list of instances.
- bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
- text (str): Annotation text.
.. code-block:: python
{
# A nested list of 4 numbers representing the bounding box of
# the instance, in (x1, y1, x2, y2) order.
'bbox': np.array([[x1, y1, x2, y2], [x1, y1, x2, y2], ...],
dtype=np.int32),
# List of texts.
"texts": ['text1', 'text2', ...],
}
"""
def _init_pipeline(self, cfg: ConfigType) -> None:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
idx = self._get_transform_idx(pipeline_cfg, 'LoadKIEAnnotations')
if idx == -1:
raise ValueError(
'LoadKIEAnnotations is not found in the test pipeline')
pipeline_cfg[idx]['with_label'] = False
self.novisual = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile') == -1
# If it's in non-visual mode, self.pipeline will be specified.
# Otherwise, file_pipeline and ndarray_pipeline will be specified.
if self.novisual:
self.pipeline = Compose(pipeline_cfg)
else:
return super()._init_pipeline(cfg)
def _init_visualizer(self, cfg: ConfigType) -> None:
"""Initialize visualizers."""
# TODO: We don't export images via backends since the interface
# of the visualizer will have to be refactored.
self.visualizer = None
if 'visualizer' in cfg:
self.visualizer = VISUALIZERS.build(cfg.visualizer)
dataset = DATASETS.build(cfg.test_dataloader.dataset)
self.visualizer.dataset_meta = dataset.metainfo
def preprocess(self, inputs: InputsType) -> List[Dict]:
results = []
for single_input in inputs:
if self.novisual:
if 'img' not in single_input and \
'img_shape' not in single_input:
raise ValueError(
'KIEInferencer in no-visual mode '
'requires input has "img" or "img_shape", but both are'
' not found.')
if 'img' in single_input:
new_input = {
k: v
for k, v in single_input.items() if k != 'img'
}
img = single_input['img']
if isinstance(img, str):
img = mmcv.imread(img)
                    new_input['img_shape'] = img.shape[:2]
results.append(self.pipeline(new_input))
else:
if 'img' not in single_input:
raise ValueError(
'This inferencer is constructed to '
'accept image inputs, but the input does not contain '
'"img" key.')
if isinstance(single_input['img'], str):
data_ = {
k: v
for k, v in single_input.items() if k != 'img'
}
data_['img_path'] = single_input['img']
results.append(self.file_pipeline(data_))
elif isinstance(single_input['img'], np.ndarray):
results.append(self.ndarray_pipeline(single_input))
else:
atype = type(single_input['img'])
raise ValueError(f'Unsupported input type: {atype}')
return self._collate(results)
def visualize(self,
inputs: InputsType,
preds: PredType,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> List[np.ndarray]:
"""Visualize predictions.
Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
"""
        if not show and img_out_dir == '':
            return None
        if self.visualizer is None:
            raise ValueError('Visualization requires the "visualizer" field '
                             'to be defined in the config, but it is None.')
results = []
for single_input, pred in zip(inputs, preds):
assert 'img' in single_input or 'img_shape' in single_input
if 'img' in single_input:
if isinstance(single_input['img'], str):
img = mmcv.imread(single_input['img'])
img_name = osp.basename(single_input['img'])
elif isinstance(single_input['img'], np.ndarray):
img = single_input['img'].copy()
img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg'
            elif 'img_shape' in single_input:
                img = np.zeros(single_input['img_shape'], dtype=np.uint8)
                img_num = str(self.num_visualized_imgs).zfill(8)
                img_name = f'{img_num}.jpg'
else:
raise ValueError('Input does not contain either "img" or '
'"img_shape"')
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None
self.visualizer.add_datasample(
img_name,
img,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
out_file=out_file,
)
results.append(img)
self.num_visualized_imgs += 1
return results
def _pred2dict(self, data_sample: KIEDataSample) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.
Args:
data_sample (TextRecogDataSample): The data sample to be converted.
Returns:
dict: The output dictionary.
"""
result = {}
pred = data_sample.pred_instances
result['scores'] = pred.scores.cpu().numpy().tolist()
result['edge_scores'] = pred.edge_scores.cpu().numpy().tolist()
result['edge_labels'] = pred.edge_labels.cpu().numpy().tolist()
result['labels'] = pred.labels.cpu().numpy().tolist()
return result
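
A rough usage sketch of the input format accepted by KIEInferencer, assuming the SDMGR config from ocr.py's model table below; the checkpoint and image paths are placeholders.

import numpy as np
from mmocr.apis.inferencers import KIEInferencer
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=True)
kie = KIEInferencer(
    config='configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py',
    ckpt='path/to/sdmgr_checkpoint.pth',  # placeholder
    device='cpu')

# Each input bundles an image (or just img_shape in no-visual mode) with
# the detected boxes and recognized texts.
kie_input = dict(
    img='demo/demo_kie.jpeg',  # placeholder; may also be an np.ndarray
    instances=[
        dict(bbox=np.array([10, 10, 100, 40], dtype=np.float32), text='Total'),
        dict(bbox=np.array([120, 10, 200, 40], dtype=np.float32), text='$1.99'),
    ])
results = kie(kie_input, print_result=True)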

View File

@ -0,0 +1,233 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
import mmcv
import numpy as np
from mmocr.registry import VISUALIZERS
from mmocr.structures.textdet_data_sample import TextDetDataSample
from mmocr.utils import ConfigType, bbox2poly, crop_img, poly2bbox
from .base_mmocr_inferencer import (BaseMMOCRInferencer, InputsType, PredType,
ResType)
from .kie_inferencer import KIEInferencer
from .textdet_inferencer import TextDetInferencer
from .textrec_inferencer import TextRecInferencer
class MMOCRInferencer(BaseMMOCRInferencer):
def __init__(self,
det_config: Optional[Union[ConfigType, str]] = None,
det_ckpt: Optional[str] = None,
rec_config: Optional[Union[ConfigType, str]] = None,
rec_ckpt: Optional[str] = None,
kie_config: Optional[Union[ConfigType, str]] = None,
kie_ckpt: Optional[str] = None,
device: Optional[str] = None,
**kwargs) -> None:
self.visualizer = None
        self.base_params = self._dispatch_kwargs(**kwargs)
self.num_visualized_imgs = 0
if det_config is not None:
self.textdet_inferencer = TextDetInferencer(
det_config, det_ckpt, device)
self.mode = 'det'
if rec_config is not None:
self.textrec_inferencer = TextRecInferencer(
rec_config, rec_ckpt, device)
if getattr(self, 'mode', None) == 'det':
self.mode = 'det_rec'
ts = str(datetime.timestamp(datetime.now()))
self.visualizer = VISUALIZERS.build(
dict(
type='TextSpottingLocalVisualizer',
name=f'inferencer{ts}'))
else:
self.mode = 'rec'
if kie_config is not None:
if det_config is None or rec_config is None:
raise ValueError(
'kie_config is only applicable when det_config and '
'rec_config are both provided')
self.kie_inferencer = KIEInferencer(kie_config, kie_ckpt, device)
self.mode = 'det_rec_kie'
def preprocess(self, inputs: InputsType):
new_inputs = []
for single_input in inputs:
if isinstance(single_input, str):
if osp.isdir(single_input):
raise ValueError('Feeding a directory is not supported')
# for img_path in os.listdir(single_input):
# new_inputs.append(
# mmcv.imread(osp.join(single_input, img_path)))
else:
single_input = mmcv.imread(single_input)
new_inputs.append(single_input)
else:
new_inputs.append(single_input)
return new_inputs
def forward(self, inputs: InputsType) -> PredType:
"""Forward the inputs to the model.
Args:
inputs (InputsType): The inputs to be forwarded.
Returns:
            Dict: The prediction results, possibly with keys "det", "rec" and
            "kie".
"""
result = {}
if self.mode == 'rec':
# The extra list wrapper here is for the ease of postprocessing
self.rec_inputs = inputs
result['rec'] = [
self.textrec_inferencer(self.rec_inputs, get_datasample=True)
]
elif self.mode.startswith('det'):
result['det'] = self.textdet_inferencer(
inputs, get_datasample=True)
if self.mode.startswith('det_rec'):
result['rec'] = []
for img, det_data_sample in zip(inputs, result['det']):
det_pred = det_data_sample.pred_instances
self.rec_inputs = []
for polygon in det_pred['polygons']:
# Roughly convert the polygon to a quadangle with
# 4 points
quad = bbox2poly(poly2bbox(polygon)).tolist()
self.rec_inputs.append(crop_img(img, quad))
result['rec'].append(
self.textrec_inferencer(
self.rec_inputs, get_datasample=True))
if self.mode == 'det_rec_kie':
self.kie_inputs = []
for img, det_data_sample, rec_data_samples in zip(
inputs, result['det'], result['rec']):
det_pred = det_data_sample.pred_instances
kie_input = dict(img=img)
kie_input['instances'] = []
for polygon, rec_data_sample in zip(
det_pred['polygons'], rec_data_samples):
kie_input['instances'].append(
dict(
bbox=poly2bbox(polygon),
text=rec_data_sample.pred_text.item))
self.kie_inputs.append(kie_input)
result['kie'] = self.kie_inferencer(
self.kie_inputs, get_datasample=True)
return result
def visualize(self, inputs: InputsType, preds: PredType,
**kwargs) -> List[np.ndarray]:
"""Visualize predictions.
Args:
inputs (List[Union[str, np.ndarray]]): Inputs for the inferencer.
preds (List[Dict]): Predictions of the model.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
"""
if 'kie' in self.mode:
return self.kie_inferencer.visualize(self.kie_inputs, preds['kie'],
**kwargs)
elif 'rec' in self.mode:
            if 'det' in self.mode:
                return super().visualize(inputs,
                                         self._pack_e2e_datasamples(preds),
                                         **kwargs)
else:
return self.textrec_inferencer.visualize(
self.rec_inputs, preds['rec'][0], **kwargs)
else:
return self.textdet_inferencer.visualize(inputs, preds['det'],
**kwargs)
def postprocess(self,
preds: PredType,
imgs: Optional[List[np.ndarray]] = None,
is_batch: bool = False,
print_result: bool = False,
pred_out_file: str = ''
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Postprocess predictions.
Args:
preds (Dict): Predictions of the model.
imgs (Optional[np.ndarray]): Visualized predictions.
is_batch (bool): Whether the inputs are in a batch.
Defaults to False.
print_result (bool): Whether to print the result.
Defaults to False.
pred_out_file (str): Output file name to store predictions
without images. Supported file formats are json, yaml/yml
and pickle/pkl. Defaults to ''.
Returns:
Dict or List[Dict]: Each dict contains the inference result of
each image. Possible keys are "det_polygons", "det_scores",
"rec_texts", "rec_scores", "kie_labels", "kie_scores",
"kie_edge_labels" and "kie_edge_scores".
"""
results = [{} for _ in range(len(next(iter(preds.values()))))]
if 'rec' in self.mode:
for i, rec_pred in enumerate(preds['rec']):
result = dict(rec_texts=[], rec_scores=[])
for rec_pred_instance in rec_pred:
pred = rec_pred_instance.pred_text
result['rec_texts'].append(pred.item)
result['rec_scores'].append(pred.score)
results[i].update(result)
if 'det' in self.mode:
for i, det_pred in enumerate(preds['det']):
det_pred_instances = det_pred.pred_instances
results[i].update(
dict(
det_polygons=det_pred_instances['polygons'],
det_scores=det_pred_instances['scores']))
if 'kie' in self.mode:
for i, kie_pred in enumerate(preds['kie']):
kie_pred_instances = kie_pred.pred_instances
                results[i].update(
                    dict(
                        kie_labels=kie_pred_instances['labels'],
                        kie_scores=kie_pred_instances['scores'],
                        kie_edge_scores=kie_pred_instances['edge_scores'],
                        kie_edge_labels=kie_pred_instances['edge_labels']))
if not is_batch:
results = results[0]
if print_result:
print(results)
if pred_out_file != '':
mmcv.dump(results, pred_out_file)
if imgs is None:
return results
return results, imgs
def _pack_e2e_datasamples(self, preds: Dict) -> List[TextDetDataSample]:
"""Pack text detection and recognition results into a list of
TextDetDataSample.
Note that it is a temporary solution since the TextSpottingDataSample
is not ready.
"""
results = []
for det_data_sample, rec_data_samples in zip(preds['det'],
preds['rec']):
texts = []
for rec_data_sample in rec_data_samples:
texts.append(rec_data_sample.pred_text.item)
det_data_sample.pred_instances.texts = texts
results.append(det_data_sample)
return results
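
A rough end-to-end sketch of MMOCRInferencer chaining detection and recognition, assuming the DBNet and CRNN configs from ocr.py's model table below; the checkpoints and demo image are placeholders.

from mmocr.apis.inferencers import MMOCRInferencer
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=True)
ocr = MMOCRInferencer(
    det_config='configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py',
    det_ckpt='path/to/dbnet_checkpoint.pth',  # placeholder
    rec_config='configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py',
    rec_ckpt='path/to/crnn_checkpoint.pth',   # placeholder
    device='cpu')

# Runs det -> crop -> rec and returns, per image, a dict with keys
# "det_polygons", "det_scores", "rec_texts" and "rec_scores".
results = ocr('demo/demo_text_ocr.jpg', pred_out_file='ocr.json')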

View File

@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from mmocr.structures import TextDetDataSample
from .base_mmocr_inferencer import BaseMMOCRInferencer
class TextDetInferencer(BaseMMOCRInferencer):
def _pred2dict(self, data_sample: TextDetDataSample) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.
Args:
data_sample (TextDetDataSample): The data sample to be converted.
Returns:
dict: The output dictionary.
"""
result = {}
pred_instances = data_sample.pred_instances
result['polygons'] = []
for polygon in pred_instances.polygons:
result['polygons'].append(polygon.tolist())
result['scores'] = pred_instances.scores.cpu().numpy().tolist()
return result

View File

@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
import numpy as np
from mmocr.structures import TextRecogDataSample
from .base_mmocr_inferencer import BaseMMOCRInferencer
class TextRecInferencer(BaseMMOCRInferencer):
def _pred2dict(self, data_sample: TextRecogDataSample) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.
Args:
data_sample (TextRecogDataSample): The data sample to be converted.
Returns:
dict: The output dictionary.
"""
result = {}
result['text'] = data_sample.pred_text.item
result['scores'] = float(np.mean(data_sample.pred_text.score))
return result
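
A rough sketch of standalone recognition: _pred2dict() above turns each TextRecogDataSample into a plain {'text': ..., 'scores': ...} dict. The CRNN config follows ocr.py's model table below; the checkpoint and image paths are placeholders.

from mmocr.apis.inferencers import TextRecInferencer
from mmocr.utils import register_all_modules

register_all_modules(init_default_scope=True)
rec = TextRecInferencer(
    config='configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py',
    ckpt='path/to/crnn_checkpoint.pth',  # placeholder
    device='cpu')

# A single cropped word image yields one {'text': ..., 'scores': ...} dict.
print(rec('demo/demo_text_recog.jpg'))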

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .adapters import MMDet2MMOCR, MMOCR2MMDet
from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs
from .loading import (LoadImageFromFile, LoadImageFromLMDB, LoadKIEAnnotations,
from .loading import (LoadImageFromFile, LoadImageFromLMDB,
LoadImageFromNDArray, LoadKIEAnnotations,
LoadOCRAnnotations)
from .ocr_transforms import RandomCrop, RandomRotate, Resize
from .textdet_transforms import (BoundedScaleAspectJitter, FixInvalidPolygon,
@ -18,5 +19,6 @@ __all__ = [
'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth',
'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter',
'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR',
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile'
'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile',
'LoadImageFromNDArray'
]

View File

@ -251,9 +251,7 @@ class PackKIEInputs(BaseTransform):
'gt_texts': 'texts',
}
def __init__(self,
meta_keys=('img_path', 'ori_shape', 'img_shape',
'scale_factor')):
def __init__(self, meta_keys=()):
self.meta_keys = meta_keys
def transform(self, results: dict) -> dict:

View File

@ -114,6 +114,54 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
return repr_str
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
"""Load an image from ``results['img']``.
    Similar to :obj:`LoadImageFromFile`, but the image has already been loaded
    as :obj:`np.ndarray` in ``results['img']``. Can be used when loading an
    image from a webcam.
Required Keys:
- img
Modified Keys:
- img
- img_path
- img_shape
- ori_shape
Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a uint8 array.
            Defaults to False.
"""
def transform(self, results: dict) -> dict:
"""Transform function to add image meta information.
Args:
            results (dict): Result dict with an image read from webcam stored
                in ``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
img = results['img']
if self.to_float32:
img = img.astype(np.float32)
if self.color_type == 'grayscale':
img = mmcv.image.rgb2gray(img)
results['img_path'] = None
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
@TRANSFORMS.register_module()
class LoadOCRAnnotations(MMCV_LoadAnnotations):
"""Load and process the ``instances`` annotation provided by dataset.

View File

@ -23,7 +23,7 @@ class SDMGRPostProcessor:
Defaults to 'none'. Options are:
- 'none': The simplest link type involving no edge
postprocessing. The edge prediction will be returned as it is.
postprocessing. The edge prediction will be returned as-is.
- 'one-to-one': One key node can be connected to one value node.
- 'one-to-many': One key node can be connected to multiple value
nodes.
@ -98,13 +98,13 @@ class SDMGRPostProcessor:
for i in range(len(data_samples)):
data_samples[i].pred_instances = InstanceData()
data_samples[i].pred_instances.labels = node_preds[i]
data_samples[i].pred_instances.scores = node_scores[i]
data_samples[i].pred_instances.labels = node_preds[i].cpu()
data_samples[i].pred_instances.scores = node_scores[i].cpu()
if self.link_type != 'none':
edge_scores[i], edge_preds[i] = self.decode_edges(
node_preds[i], edge_scores[i], edge_preds[i])
data_samples[i].pred_instances.edge_labels = edge_preds[i]
data_samples[i].pred_instances.edge_scores = edge_scores[i]
data_samples[i].pred_instances.edge_labels = edge_preds[i].cpu()
data_samples[i].pred_instances.edge_scores = edge_scores[i].cpu()
return data_samples
@ -167,4 +167,4 @@ class SDMGRPostProcessor:
elif self.link_type == 'many-to-one':
tmp_edge_scores[i, :] = -1
return new_edge_scores, new_edge_labels
return new_edge_scores.cpu(), new_edge_labels.cpu()

View File

@ -78,7 +78,7 @@ class BaseTextRecogPostprocessor:
data_sample (TextRecogDataSample): Datasample of an image.
Returns:
tuple(list[int], list[float]): index and score.
tuple(list[int], list[float]): Index and scores per-character.
"""
raise NotImplementedError

mmocr/ocr.py (new file, mode 100755, 471 lines)
View File

@ -0,0 +1,471 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from mmocr.apis.inferencers import MMOCRInferencer
from mmocr.apis.inferencers.base_mmocr_inferencer import InputsType
from mmocr.utils import register_all_modules
def parse_args():
parser = ArgumentParser()
parser.add_argument(
'img', type=str, help='Input image file or folder path.')
parser.add_argument(
'--img-out-dir',
type=str,
default='',
help='Output directory of images.')
parser.add_argument(
'--det',
type=str,
default=None,
help='Pretrained text detection algorithm')
parser.add_argument(
'--det-config',
type=str,
default=None,
help='Path to the custom config file of the selected det model. It '
'overrides the settings in det')
parser.add_argument(
'--det-ckpt',
type=str,
default=None,
help='Path to the custom checkpoint file of the selected det model. '
'It overrides the settings in det')
parser.add_argument(
'--recog',
type=str,
default=None,
help='Pretrained text recognition algorithm')
parser.add_argument(
'--recog-config',
type=str,
default=None,
        help='Path to the custom config file of the selected recog model. It '
        'overrides the settings in recog')
parser.add_argument(
'--recog-ckpt',
type=str,
default=None,
help='Path to the custom checkpoint file of the selected recog model. '
'It overrides the settings in recog')
parser.add_argument(
'--kie',
type=str,
default=None,
help='Pretrained key information extraction algorithm')
parser.add_argument(
'--kie-config',
type=str,
default=None,
        help='Path to the custom config file of the selected kie model. It '
        'overrides the settings in kie')
parser.add_argument(
'--kie-ckpt',
type=str,
default=None,
help='Path to the custom checkpoint file of the selected kie model. '
'It overrides the settings in kie')
parser.add_argument(
'--config-dir',
type=str,
default=os.path.join(str(Path.cwd()), 'configs/'),
help='Path to the config directory where all the config files '
'are located. Defaults to "configs/"')
parser.add_argument(
'--device',
type=str,
default='cuda',
help='Device used for inference.')
parser.add_argument(
'--show',
action='store_true',
help='Display the image in a popup window.')
parser.add_argument(
'--print-result',
action='store_true',
help='Whether to print the results.')
parser.add_argument(
'--pred-out-file',
type=str,
default='',
help='File to save the inference results.')
args = parser.parse_args()
# Warnings
    if not os.path.samefile(args.config_dir, str(Path.cwd())) and (
            args.det_config is not None or args.recog_config is not None):
warnings.warn(
'config_dir will be overridden by det-config or recog-config.',
UserWarning)
return args
class MMOCR:
"""MMOCR API for text detection, recognition, KIE inference.
Args:
        det (str): Name of the detection model. Defaults to None.
        det_config (str): Path to the config file for the detection model.
            Defaults to None.
        det_ckpt (str): Path to the checkpoint file for the detection model.
            Defaults to None.
        recog (str): Name of the recognition model. Defaults to None.
        recog_config (str): Path to the config file for the recognition model.
            Defaults to None.
        recog_ckpt (str): Path to the checkpoint file for the recognition
            model. Defaults to None.
        kie (str): Name of the KIE model. Defaults to None.
        kie_config (str): Path to the config file for the KIE model.
            Defaults to None.
        kie_ckpt (str): Path to the checkpoint file for the KIE model.
            Defaults to None.
        config_dir (str): Path to the directory containing config files.
            Defaults to 'configs/'.
        device (torch.device or str): Device to use for inference.
            Defaults to 'cuda'.
"""
def __init__(self,
det: str = None,
det_config: str = None,
det_ckpt: str = None,
recog: str = None,
recog_config: str = None,
recog_ckpt: str = None,
kie: str = None,
kie_config: str = None,
kie_ckpt: str = None,
config_dir: str = os.path.join(str(Path.cwd()), 'configs/'),
device: torch.device = 'cuda',
**kwargs) -> None:
register_all_modules(init_default_scope=True)
self.config_dir = config_dir
inferencer_kwargs = {}
inferencer_kwargs.update(
self._get_inferencer_kwargs(det, det_config, det_ckpt, 'det_'))
inferencer_kwargs.update(
self._get_inferencer_kwargs(recog, recog_config, recog_ckpt,
'rec_'))
inferencer_kwargs.update(
self._get_inferencer_kwargs(kie, kie_config, kie_ckpt, 'kie_'))
self.inferencer = MMOCRInferencer(device=device, **inferencer_kwargs)
def _get_inferencer_kwargs(self, model: Optional[str],
config: Optional[str], ckpt: Optional[str],
prefix: str) -> Dict:
"""Get the kwargs for the inferencer."""
kwargs = {}
if model is not None:
cfgs = self.get_model_config(model)
kwargs[prefix + 'config'] = os.path.join(self.config_dir,
cfgs['config'])
kwargs[prefix + 'ckpt'] = 'https://download.openmmlab.com/' + \
f'mmocr/{cfgs["ckpt"]}'
if config is not None:
if kwargs.get(prefix + 'config', None) is not None:
warnings.warn(
f'{model}\'s default config is overridden by {config}',
UserWarning)
kwargs[prefix + 'config'] = config
if ckpt is not None:
if kwargs.get(prefix + 'ckpt', None) is not None:
warnings.warn(
f'{model}\'s default checkpoint is overridden by {ckpt}',
UserWarning)
kwargs[prefix + 'ckpt'] = ckpt
return kwargs
def readtext(self,
img: InputsType,
img_out_dir: str = '',
show: bool = False,
print_result: bool = False,
pred_out_file: str = '',
**kwargs) -> Union[Dict, List[Dict]]:
"""Inferences text detection, recognition, and KIE on an image or a
folder of images.
Args:
            img (str or np.ndarray or Sequence[str or np.ndarray]): Image
                path, folder path, np.ndarray, or a list/tuple of image paths
                or np.ndarrays.
            img_out_dir (str): Output directory of images. Defaults to ''.
            show (bool): Whether to display the image in a popup window.
                Defaults to False.
            print_result (bool): Whether to print the results.
                Defaults to False.
            pred_out_file (str): File to save the inference results. If left
                empty, no file will be saved. Defaults to ''.
Returns:
Dict or List[Dict]: Each dict contains the inference result of
each image. Possible keys are "det_polygons", "det_scores",
"rec_texts", "rec_scores", "kie_labels", "kie_scores",
"kie_edge_labels" and "kie_edge_scores".
"""
return self.inferencer(
img,
img_out_dir=img_out_dir,
show=show,
print_result=print_result,
pred_out_file=pred_out_file)
def get_model_config(self, model_name: str) -> Dict:
"""Get the model configuration including model config and checkpoint
url.
Args:
model_name (str): Name of the model.
Returns:
dict: Model configuration.
"""
model_dict = {
'Tesseract': {},
# Detection models
'DB_r18': {
'config':
'textdet/'
'dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py',
'ckpt':
'textdet/'
'dbnet/'
'dbnet_resnet18_fpnc_1200e_icdar2015/'
'dbnet_resnet18_fpnc_1200e_icdar2015_20220825_221614-7c0e94f2.pth' # noqa: E501
},
'DB_r50': {
'config':
'textdet/'
'dbnet/dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'textdet/'
'dbnet/'
                'dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015/'
'dbnet_resnet50-dcnv2_fpnc_1200e_icdar2015_20220828_124917-452c443c.pth' # noqa: E501
},
'DBPP_r50': {
'config':
'textdet/'
'dbnetpp/dbnetpp_resnet50-dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'textdet/'
'dbnetpp/'
'dbnetpp_resnet50-dcnv2_fpnc_1200e_icdar2015/'
'dbnetpp_resnet50-dcnv2_fpnc_1200e_icdar2015_20220829_230108-f289bd20.pth' # noqa: E501
},
'DRRG': {
'config':
'textdet/'
'drrg/drrg_resnet50_fpn-unet_1200e_ctw1500.py',
'ckpt':
'textdet/'
'drrg/'
'drrg_resnet50_fpn-unet_1200e_ctw1500/'
'drrg_resnet50_fpn-unet_1200e_ctw1500_20220827_105233-d5c702dd.pth' # noqa: E501
},
'FCE_IC15': {
'config':
'textdet/'
'fcenet/fcenet_resnet50_fpn_1500e_icdar2015.py',
'ckpt':
'textdet/'
'fcenet/'
'fcenet_resnet50_fpn_1500e_icdar2015/'
'fcenet_resnet50_fpn_1500e_icdar2015_20220826_140941-167d9042.pth' # noqa: E501
},
'FCE_CTW_DCNv2': {
'config':
'textdet/'
'fcenet/fcenet_resnet50-dcnv2_fpn_1500e_ctw1500.py',
'ckpt':
'textdet/'
'fcenet/'
'fcenet_resnet50-dcnv2_fpn_1500e_ctw1500/'
'fcenet_resnet50-dcnv2_fpn_1500e_ctw1500_20220825_221510-4d705392.pth' # noqa: E501
},
'MaskRCNN_CTW': {
'config':
'textdet/'
'maskrcnn/mask-rcnn_resnet50_fpn_160e_ctw1500.py',
'ckpt':
'textdet/'
'maskrcnn/'
'mask-rcnn_resnet50_fpn_160e_ctw1500/'
'mask-rcnn_resnet50_fpn_160e_ctw1500_20220826_154755-ce68ee8e.pth' # noqa: E501
},
'MaskRCNN_IC15': {
'config':
'textdet/'
'maskrcnn/mask-rcnn_resnet50_fpn_160e_icdar2015.py',
'ckpt':
'textdet/'
'maskrcnn/'
'mask-rcnn_resnet50_fpn_160e_icdar2015/'
'mask-rcnn_resnet50_fpn_160e_icdar2015_20220826_154808-ff5c30bf.pth' # noqa: E501
},
# 'MaskRCNN_IC17': {
# 'config':
# 'textdet/'
# 'maskrcnn/mask-rcnn_resnet50_fpn_160e_icdar2017.py',
# 'ckpt':
# 'textdet/'
# 'maskrcnn/'
# ''
# },
'PANet_CTW': {
'config':
'textdet/'
'panet/panet_resnet18_fpem-ffm_600e_ctw1500.py',
'ckpt':
'textdet/'
'panet/'
'panet_resnet18_fpem-ffm_600e_ctw1500/'
'panet_resnet18_fpem-ffm_600e_ctw1500_20220826_144818-980f32d0.pth' # noqa: E501
},
'PANet_IC15': {
'config':
'textdet/'
'panet/panet_resnet18_fpem-ffm_600e_icdar2015.py',
'ckpt':
'textdet/'
'panet/'
'panet_resnet18_fpem-ffm_600e_icdar2015/'
'panet_resnet18_fpem-ffm_600e_icdar2015_20220826_144817-be2acdb4.pth' # noqa: E501
},
'PS_CTW': {
'config':
'textdet/'
'psenet/psenet_resnet50_fpnf_600e_ctw1500.py',
'ckpt':
'textdet/'
'psenet/'
'psenet_resnet50_fpnf_600e_ctw1500/'
'psenet_resnet50_fpnf_600e_ctw1500_20220825_221459-7f974ac8.pth' # noqa: E501
},
'PS_IC15': {
'config':
'textdet/'
'psenet/psenet_resnet50_fpnf_600e_icdar2015.py',
'ckpt':
'textdet/'
'psenet/'
'psenet_resnet50_fpnf_600e_icdar2015/'
'psenet_resnet50_fpnf_600e_icdar2015_20220825_222709-b6741ec3.pth' # noqa: E501
},
'TextSnake': {
'config':
'textdet/'
'textsnake/textsnake_resnet50_fpn-unet_1200e_ctw1500.py',
'ckpt':
'textdet/'
'textsnake/'
'textsnake_resnet50_fpn-unet_1200e_ctw1500/'
'textsnake_resnet50_fpn-unet_1200e_ctw1500_20220825_221459-c0b6adc4.pth' # noqa: E501
},
# Recognition models
'CRNN': {
'config':
'textrecog/crnn/crnn_mini-vgg_5e_mj.py',
'ckpt':
'textrecog/crnn/crnn_mini-vgg_5e_mj/crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth' # noqa: E501
},
# 'SAR': {
# 'config':
# 'textrecog/sar/'
# 'sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real.py',
# 'ckpt':
# ''
# },
# 'SAR_CN': {
# 'config':
# 'textrecog/'
# 'sar/sar_r31_parallel_decoder_chinese.py',
# 'ckpt':
# 'textrecog/'
# ''
# },
# 'NRTR_1/16-1/8': {
# 'config':
# 'textrecog/'
# 'nrtr/nrtr_resnet31-1by16-1by8_6e_st_mj.py',
# 'ckpt':
# 'textrecog/'
# ''
# },
# 'NRTR_1/8-1/4': {
# 'config':
# 'textrecog/'
# 'nrtr/nrtr_resnet31-1by8-1by4_6e_st_mj.py',
# 'ckpt':
# 'textrecog/'
# ''
# },
# 'RobustScanner': {
# 'config':
# 'textrecog/robust_scanner/'
# 'robustscanner_resnet31_5e_st-sub_mj-sub_sa_real.py',
# 'ckpt':
# 'textrecog/'
# ''
# },
# 'SATRN': {
# 'config': 'textrecog/satrn/satrn_shallow_5e_st_mj.py',
# 'ckpt': ''
# },
# 'SATRN_sm': {
# 'config': 'textrecog/satrn/satrn_shallow-small_5e_st_mj.py',
# 'ckpt': ''
# },
# 'ABINet': {
# 'config': 'textrecog/abinet/abinet_20e_st-an_mj.py',
# 'ckpt': ''
# },
# 'ABINet_Vision': {
# 'config': 'textrecog/abinet/abinet-vision_20e_st-an_mj.py',
# 'ckpt': ''
# },
# 'CRNN_TPS': {
# 'config':
# 'textrecog/tps/crnn_tps_academic_dataset.py',
# 'ckpt':
# ''
# },
# 'MASTER': {
# 'config': 'textrecog/master/master_resnet31_12e_st_mj_sa.py',
# 'ckpt': ''
# },
# KIE models
'SDMGR': {
'config':
'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py',
'ckpt':
'kie/'
'sdmgr/'
'sdmgr_unet16_60e_wildreceipt/'
'sdmgr_unet16_60e_wildreceipt_20220825_151648-22419f37.pth'
}
}
if model_name not in model_dict:
raise ValueError(f'Model {model_name} is not supported.')
else:
return model_dict[model_name]
def main():
args = parse_args()
ocr = MMOCR(**vars(args))
ocr.readtext(**vars(args))
if __name__ == '__main__':
main()
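
A rough usage sketch of the refactored high-level API, assuming the code is run from the repository root so that the default config_dir resolves; model names are mapped to configs and download URLs by get_model_config() above, and the demo image path is a placeholder.

from mmocr.ocr import MMOCR

# Checkpoints for named models are fetched from download.openmmlab.com.
ocr = MMOCR(det='DB_r18', recog='CRNN', device='cpu')
results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True)

# Roughly equivalent CLI call (paths are illustrative):
#   python mmocr/ocr.py demo/demo_text_ocr.jpg --det DB_r18 --recog CRNN \
#       --device cpu --print-result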

View File

@ -1,877 +0,0 @@
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
import mmcv
import numpy as np
import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config
from PIL import Image
try:
import tesserocr
except ImportError:
tesserocr = None
from mmocr.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.datasets import WildReceiptDataset
from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizers import BaseRecognizer
from mmocr.registry import MODELS
from mmocr.utils import is_type_list, stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
from mmocr.utils.img_utils import crop_img
from mmocr.utils.model import revert_sync_batchnorm
from mmocr.visualization.visualize import det_recog_show_result
# Parse CLI arguments
def parse_args():
parser = ArgumentParser()
parser.add_argument(
'img', type=str, help='Input image file or folder path.')
parser.add_argument(
'--output',
type=str,
default='',
help='Output file/folder name for visualization')
parser.add_argument(
'--det',
type=str,
default='PANet_IC15',
help='Pretrained text detection algorithm')
parser.add_argument(
'--det-config',
type=str,
default='',
help='Path to the custom config file of the selected det model. It '
'overrides the settings in det')
parser.add_argument(
'--det-ckpt',
type=str,
default='',
help='Path to the custom checkpoint file of the selected det model. '
'It overrides the settings in det')
parser.add_argument(
'--recog',
type=str,
default='SEG',
help='Pretrained text recognition algorithm')
parser.add_argument(
'--recog-config',
type=str,
default='',
help='Path to the custom config file of the selected recog model. It'
'overrides the settings in recog')
parser.add_argument(
'--recog-ckpt',
type=str,
default='',
help='Path to the custom checkpoint file of the selected recog model. '
'It overrides the settings in recog')
parser.add_argument(
'--kie',
type=str,
default='',
help='Pretrained key information extraction algorithm')
parser.add_argument(
'--kie-config',
type=str,
default='',
help='Path to the custom config file of the selected kie model. It'
'overrides the settings in kie')
parser.add_argument(
'--kie-ckpt',
type=str,
default='',
help='Path to the custom checkpoint file of the selected kie model. '
'It overrides the settings in kie')
parser.add_argument(
'--config-dir',
type=str,
default=os.path.join(str(Path.cwd()), 'configs/'),
help='Path to the config directory where all the config files '
'are located. Defaults to "configs/"')
parser.add_argument(
'--batch-mode',
action='store_true',
help='Whether use batch mode for inference')
parser.add_argument(
'--recog-batch-size',
type=int,
default=0,
help='Batch size for text recognition')
parser.add_argument(
'--det-batch-size',
type=int,
default=0,
help='Batch size for text detection')
parser.add_argument(
'--single-batch-size',
type=int,
default=0,
help='Batch size for separate det/recog inference')
parser.add_argument(
'--device', default=None, help='Device used for inference.')
parser.add_argument(
'--export',
type=str,
default='',
help='Folder where the results of each image are exported')
parser.add_argument(
'--export-format',
type=str,
default='json',
help='Format of the exported result file(s)')
parser.add_argument(
'--details',
action='store_true',
help='Whether include the text boxes coordinates and confidence values'
)
parser.add_argument(
'--imshow',
action='store_true',
help='Whether show image with OpenCV.')
parser.add_argument(
'--print-result',
action='store_true',
help='Prints the recognised text')
parser.add_argument(
'--merge', action='store_true', help='Merge neighboring boxes')
parser.add_argument(
'--merge-xdist',
type=float,
default=20,
help='The maximum x-axis distance to merge boxes')
args = parser.parse_args()
if args.det == 'None':
args.det = None
if args.recog == 'None':
args.recog = None
# Warnings
if args.merge and not (args.det and args.recog):
warnings.warn(
'Box merging will not work if the script is not'
' running in detection + recognition mode.', UserWarning)
if not os.path.samefile(args.config_dir, os.path.join(str(
Path.cwd()))) and (args.det_config != ''
or args.recog_config != ''):
warnings.warn(
'config_dir will be overridden by det-config or recog-config.',
UserWarning)
return args
class MMOCR:
def __init__(self,
det='PANet_IC15',
det_config='',
det_ckpt='',
recog='SEG',
recog_config='',
recog_ckpt='',
kie='',
kie_config='',
kie_ckpt='',
config_dir=os.path.join(str(Path.cwd()), 'configs/'),
device=None,
**kwargs):
textdet_models = {
'DB_r18': {
'config':
'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
},
'DB_r50': {
'config':
'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20211025-9fe3b590.pth'
},
'DBPP_r50': {
'config':
'dbnetpp/dbnetpp_r50dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnetpp_r50dcnv2_fpnc_1200e_icdar2015-20220502-d7a76fff.pth'
},
'DRRG': {
'config':
'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
'ckpt':
'drrg/drrg_r50_fpn_unet_1200e_ctw1500_20211022-fb30b001.pth'
},
'FCE_IC15': {
'config':
'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
'ckpt':
'fcenet/fcenet_r50_fpn_1500e_icdar2015_20211022-daefb6ed.pth'
},
'FCE_CTW_DCNv2': {
'config':
'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
'ckpt':
'fcenet/' +
'fcenet_r50dcnv2_fpn_1500e_ctw1500_20211022-e326d7ec.pth'
},
'MaskRCNN_CTW': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
},
'MaskRCNN_IC15': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
},
'MaskRCNN_IC17': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
},
'PANet_CTW': {
'config':
'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
},
'PANet_IC15': {
'config':
'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
},
'PS_CTW': {
'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
'ckpt':
'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
},
'PS_IC15': {
'config':
'psenet/psenet_r50_fpnf_600e_icdar2015.py',
'ckpt':
'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
},
'TextSnake': {
'config':
'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
'ckpt':
'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
},
'Tesseract': {}
}
textrecog_models = {
'CRNN': {
'config': 'crnn/crnn_academic_dataset.py',
'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
},
'SAR': {
'config': 'sar/sar_r31_parallel_decoder_academic.py',
'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
},
'SAR_CN': {
'config':
'sar/sar_r31_parallel_decoder_chinese.py',
'ckpt':
'sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth'
},
'NRTR_1/16-1/8': {
'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
'ckpt':
'nrtr/nrtr_r31_1by16_1by8_academic_20211124-f60cebf4.pth'
},
'NRTR_1/8-1/4': {
'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
'ckpt':
'nrtr/nrtr_r31_1by8_1by4_academic_20211123-e1fdb322.pth'
},
'RobustScanner': {
'config': 'robust_scanner/robustscanner_r31_academic.py',
'ckpt': 'robustscanner/robustscanner_r31_academic-5f05874f.pth'
},
'SATRN': {
'config': 'satrn/satrn_academic.py',
'ckpt': 'satrn/satrn_academic_20211009-cb8b1580.pth'
},
'SATRN_sm': {
'config': 'satrn/satrn_small.py',
'ckpt': 'satrn/satrn_small_20211009-2cf13355.pth'
},
'ABINet': {
'config': 'abinet/abinet_academic.py',
'ckpt': 'abinet/abinet_academic-f718abf6.pth'
},
'ABINet_Vision': {
'config': 'abinet/abinet_vision_only_academic.py',
'ckpt': 'abinet/abinet_vision_only_academic-e6b9ea89.pth'
},
'SEG': {
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
},
'CRNN_TPS': {
'config': 'tps/crnn_tps_academic_dataset.py',
'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
},
'Tesseract': {},
'MASTER': {
'config': 'master/master_r31_12e_ST_MJ_SA.py',
'ckpt': 'master/master_r31_12e_ST_MJ_SA-787edd36.pth'
}
}
kie_models = {
'SDMGR': {
'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
'ckpt':
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
}
}
self.td = det
self.tr = recog
self.kie = kie
self.device = device
if self.device is None:
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
# Check if the det/recog model choice is valid
if self.td and self.td not in textdet_models:
raise ValueError(self.td,
                             'is not a supported text detection algorithm')
elif self.tr and self.tr not in textrecog_models:
raise ValueError(self.tr,
'is not a supported text recognition algorithm')
elif self.kie:
if self.kie not in kie_models:
raise ValueError(
self.kie, 'is not a supported key information extraction'
' algorithm')
elif not (self.td and self.tr):
raise NotImplementedError(
self.kie, 'has to run together'
' with text detection and recognition algorithms.')
self.detect_model = None
if self.td and self.td == 'Tesseract':
if tesserocr is None:
raise ImportError('Please install tesserocr first. '
'Check out the installation guide at '
'https://github.com/sirfz/tesserocr')
self.detect_model = 'Tesseract_det'
elif self.td:
# Build detection model
if not det_config:
det_config = os.path.join(config_dir, 'textdet/',
textdet_models[self.td]['config'])
if not det_ckpt:
det_ckpt = 'https://download.openmmlab.com/mmocr/textdet/' + \
textdet_models[self.td]['ckpt']
self.detect_model = init_detector(
det_config, det_ckpt, device=self.device)
self.detect_model = revert_sync_batchnorm(self.detect_model)
self.recog_model = None
if self.tr and self.tr == 'Tesseract':
if tesserocr is None:
raise ImportError('Please install tesserocr first. '
'Check out the installation guide at '
'https://github.com/sirfz/tesserocr')
self.recog_model = 'Tesseract_recog'
elif self.tr:
# Build recognition model
if not recog_config:
recog_config = os.path.join(
config_dir, 'textrecog/',
textrecog_models[self.tr]['config'])
if not recog_ckpt:
recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'textrecog/' + textrecog_models[self.tr]['ckpt']
self.recog_model = init_detector(
recog_config, recog_ckpt, device=self.device)
self.recog_model = revert_sync_batchnorm(self.recog_model)
self.kie_model = None
if self.kie:
# Build key information extraction model
if not kie_config:
kie_config = os.path.join(config_dir, 'kie/',
kie_models[self.kie]['config'])
if not kie_ckpt:
kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'kie/' + kie_models[self.kie]['ckpt']
kie_cfg = Config.fromfile(kie_config)
self.kie_model = MODELS.build(
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
self.kie_model = revert_sync_batchnorm(self.kie_model)
self.kie_model.cfg = kie_cfg
load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
# Attribute check
for model in list(filter(None, [self.recog_model, self.detect_model])):
if hasattr(model, 'module'):
model = model.module

@staticmethod
def get_tesserocr_api():
"""Get tesserocr api depending on different platform."""
import subprocess
import sys
if sys.platform == 'linux':
api = tesserocr.PyTessBaseAPI()
elif sys.platform == 'win32':
try:
p = subprocess.Popen(
'where tesseract', stdout=subprocess.PIPE, shell=True)
s = p.communicate()[0].decode('utf-8').split('\\')
path = s[:-1] + ['tessdata']
tessdata_path = '/'.join(path)
api = tesserocr.PyTessBaseAPI(path=tessdata_path)
except RuntimeError:
raise RuntimeError(
'Please install tesseract first.\n Check out the'
' installation guide at'
' https://github.com/UB-Mannheim/tesseract/wiki')
else:
raise NotImplementedError
return api
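# Usage sketch (illustrative only; assumes tesserocr and PIL are importable
# as in the rest of this file):
# >>> api = MMOCR.get_tesserocr_api()
# >>> api.SetImage(Image.fromarray(mmcv.imread('demo/demo_text_det.jpg')))
# >>> print(api.GetUTF8Text().strip())
# >>> api.End()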
def tesseract_det_inference(self, imgs, **kwargs):
"""Inference image(s) with the tesseract detector.
Args:
imgs (ndarray or list[ndarray]): image(s) to inference.
Returns:
result (dict or list[dict]): Predicted results.
"""
is_batch = True
if isinstance(imgs, np.ndarray):
is_batch = False
imgs = [imgs]
assert is_type_list(imgs, np.ndarray)
api = self.get_tesserocr_api()
# Get detection result using tesseract
results = []
for img in imgs:
image = Image.fromarray(img)
api.SetImage(image)
boxes = api.GetComponentImages(tesserocr.RIL.TEXTLINE, True)
boundaries = []
for _, box, _, _ in boxes:
min_x = box['x']
min_y = box['y']
max_x = box['x'] + box['w']
max_y = box['y'] + box['h']
boundary = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0
]
boundaries.append(boundary)
results.append({'boundary_result': boundaries})
# close tesserocr api
api.End()
if not is_batch:
return results[0]
else:
return results
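# Result layout for a single input (following the loop above): a dict with a
# 'boundary_result' list, where every entry is an axis-aligned quadrangle in
# clockwise order plus a constant 1.0 confidence, e.g.
#   {'boundary_result': [[min_x, min_y, max_x, min_y, max_x, max_y,
#                         min_x, max_y, 1.0], ...]}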
def tesseract_recog_inference(self, imgs, **kwargs):
"""Inference image(s) with the tesseract recognizer.
Args:
imgs (ndarray or list[ndarray]): image(s) to inference.
Returns:
result (dict or list[dict]): Predicted results.
"""
is_batch = True
if isinstance(imgs, np.ndarray):
is_batch = False
imgs = [imgs]
assert is_type_list(imgs, np.ndarray)
api = self.get_tesserocr_api()
results = []
for img in imgs:
image = Image.fromarray(img)
api.SetImage(image)
api.SetRectangle(0, 0, img.shape[1], img.shape[0])
# Remove beginning and trailing spaces from Tesseract
text = api.GetUTF8Text().strip()
conf = api.MeanTextConf() / 100
results.append({'text': text, 'score': conf})
# close tesserocr api
api.End()
if not is_batch:
return results[0]
else:
return results
def readtext(self,
img,
output=None,
details=False,
export=None,
export_format='json',
batch_mode=False,
recog_batch_size=0,
det_batch_size=0,
single_batch_size=0,
imshow=False,
print_result=False,
merge=False,
merge_xdist=20,
**kwargs):
args = locals().copy()
for x in ['kwargs', 'self']:
    args.pop(x, None)
args = Namespace(**args)
# Input and output arguments processing
self._args_processing(args)
self.args = args
pp_result = None
# Send args and models to the MMOCR model inference API
# and call post-processing functions for the output
if self.detect_model and self.recog_model:
det_recog_result = self.det_recog_kie_inference(
self.detect_model, self.recog_model, kie_model=self.kie_model)
pp_result = self.det_recog_pp(det_recog_result)
else:
for model in list(
filter(None, [self.recog_model, self.detect_model])):
result = self.single_inference(model, args.arrays,
args.batch_mode,
args.single_batch_size)
pp_result = self.single_pp(result, model)
return pp_result
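# A minimal end-to-end usage sketch (import path and model keys are
# illustrative; any entry from textdet_models / textrecog_models works):
# >>> from mmocr.ocr import MMOCR  # adjust to where this file is placed
# >>> ocr = MMOCR(det='DB_r18', recog='ABINet')
# >>> results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True)
# With details=False (the default), each returned dict only carries
# 'filename' and the list of recognized 'text' strings.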
# Post processing function for end2end ocr
def det_recog_pp(self, result):
final_results = []
args = self.args
for arr, output, export, det_recog_result in zip(
args.arrays, args.output, args.export, result):
if output or args.imshow:
if self.kie_model:
res_img = det_recog_show_result(arr, det_recog_result)
else:
res_img = det_recog_show_result(
arr, det_recog_result, out_file=output)
if args.imshow and not self.kie_model:
mmcv.imshow(res_img, 'inference results')
if not args.details:
simple_res = {}
simple_res['filename'] = det_recog_result['filename']
simple_res['text'] = [
x['text'] for x in det_recog_result['result']
]
final_result = simple_res
else:
final_result = det_recog_result
if export:
mmcv.dump(final_result, export, indent=4)
if args.print_result:
print(final_result, end='\n\n')
final_results.append(final_result)
return final_results
# Post processing function for separate det/recog inference
def single_pp(self, result, model):
for arr, output, export, res in zip(self.args.arrays, self.args.output,
self.args.export, result):
if export:
mmcv.dump(res, export, indent=4)
if output or self.args.imshow:
if model == 'Tesseract_det':
res_img = TextDetectorMixin(show_score=False).show_result(
arr, res, out_file=output)
elif model == 'Tesseract_recog':
res_img = BaseRecognizer.show_result(
arr, res, out_file=output)
else:
res_img = model.show_result(arr, res, out_file=output)
if self.args.imshow:
mmcv.imshow(res_img, 'inference results')
if self.args.print_result:
print(res, end='\n\n')
return result
def generate_kie_labels(self, result, boxes, class_list):
idx_to_cls = {}
if class_list is not None:
for line in list_from_file(class_list):
class_idx, class_label = line.strip().split()
idx_to_cls[class_idx] = class_label
max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
labels = []
for i in range(len(boxes)):
pred_label = str(node_pred_label[i])
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = node_pred_score[i]
labels.append((pred_label, pred_score))
return labels
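# ``class_list`` sketch (format implied by the parsing above): a plain-text
# file with one "<index> <label>" pair per line, e.g.
#   0 Ignore
#   1 Store_name_value
# Node indices missing from the file are kept as raw index strings.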
def visualize_kie_output(self,
model,
data,
result,
out_file=None,
show=False):
"""Visualizes KIE output."""
img_tensor = data['img'].data
img_meta = data['img_metas'].data
gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
if img_tensor.dtype == torch.uint8:
# The img tensor is the raw input that has not been normalized
# (as used by the non-visual variant of SDMGR)
img = img_tensor.cpu().numpy().transpose(1, 2, 0)
else:
img = tensor2imgs(
img_tensor.unsqueeze(0), **img_meta.get('img_norm_cfg', {}))[0]
h, w, _ = img_meta.get('img_shape', img.shape)
img_show = img[:h, :w, :]
model.show_result(
img_show, result, gt_bboxes, show=show, out_file=out_file)
# End2end ocr inference pipeline
def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
end2end_res = []
# Find bounding boxes in the images (text detection)
det_result = self.single_inference(det_model, self.args.arrays,
self.args.batch_mode,
self.args.det_batch_size)
bboxes_list = [res['boundary_result'] for res in det_result]
if kie_model:
kie_dataset = WildReceiptDataset(
dict_file=kie_model.cfg.data.test.dict_file)
# For each detected bounding box, crop the corresponding image patch and
# send it to the recognition model, either one by one or all at once,
# depending on batch_mode
for filename, arr, bboxes, out_file in zip(self.args.filenames,
self.args.arrays,
bboxes_list,
self.args.output):
img_e2e_res = {}
img_e2e_res['filename'] = filename
img_e2e_res['result'] = []
box_imgs = []
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
]
box_img = crop_img(arr, box)
if self.args.batch_mode:
box_imgs.append(box_img)
else:
if recog_model == 'Tesseract_recog':
recog_result = self.single_inference(
recog_model, box_img, batch_mode=True)
else:
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
img_e2e_res['result'].append(box_res)
if self.args.batch_mode:
recog_results = self.single_inference(
recog_model, box_imgs, True, self.args.recog_batch_size)
for i, recog_result in enumerate(recog_results):
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, (list, tuple)):
text_score = sum(text_score) / max(1, len(text))
img_e2e_res['result'][i]['text'] = text
img_e2e_res['result'][i]['text_score'] = text_score
if self.args.merge:
img_e2e_res['result'] = stitch_boxes_into_lines(
img_e2e_res['result'], self.args.merge_xdist, 0.5)
if kie_model:
annotations = copy.deepcopy(img_e2e_res['result'])
# Customized for kie_dataset, which expects every box to be an
# axis-aligned quadrangle described by its 4 corner points
for i, ann in enumerate(annotations):
min_x = min(ann['box'][::2])
min_y = min(ann['box'][1::2])
max_x = max(ann['box'][::2])
max_y = max(ann['box'][1::2])
annotations[i]['box'] = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
]
ann_info = kie_dataset._parse_anno_info(annotations)
ann_info['ori_bboxes'] = ann_info.get('ori_bboxes',
ann_info['bboxes'])
ann_info['gt_bboxes'] = ann_info.get('gt_bboxes',
ann_info['bboxes'])
kie_result, data = model_inference(
kie_model,
arr,
ann=ann_info,
return_data=True,
batch_mode=self.args.batch_mode)
# visualize KIE results
self.visualize_kie_output(
kie_model,
data,
kie_result,
out_file=out_file,
show=self.args.imshow)
gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
labels = self.generate_kie_labels(kie_result, gt_bboxes,
kie_model.class_list)
for i in range(len(gt_bboxes)):
img_e2e_res['result'][i]['label'] = labels[i][0]
img_e2e_res['result'][i]['label_score'] = labels[i][1]
end2end_res.append(img_e2e_res)
return end2end_res
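# Shape of each element appended to ``end2end_res`` (values illustrative):
# {'filename': 'demo_text_ocr',
#  'result': [{'box': [x1, y1, ..., x4, y4], 'box_score': 0.95,
#              'text': 'mmocr', 'text_score': 0.98,
#              'label': 'Others', 'label_score': 0.87},  # label* only with KIE
#             ...]}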
# Separate det/recog inference pipeline
def single_inference(self, model, arrays, batch_mode, batch_size=0):
def inference(m, a, **kwargs):
if model == 'Tesseract_det':
return self.tesseract_det_inference(a)
elif model == 'Tesseract_recog':
return self.tesseract_recog_inference(a)
else:
return model_inference(m, a, **kwargs)
result = []
if batch_mode:
if batch_size == 0:
result = inference(model, arrays, batch_mode=True)
else:
n = batch_size
arr_chunks = [
arrays[i:i + n] for i in range(0, len(arrays), n)
]
for chunk in arr_chunks:
result.extend(inference(model, chunk, batch_mode=True))
else:
for arr in arrays:
result.append(inference(model, arr, batch_mode=False))
return result
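# Batching behaviour sketch: with batch_mode=True and batch_size=3, eight
# inputs are forwarded as chunks of [3, 3, 2] arrays per call; with
# batch_size=0 all inputs go through in a single call, and with
# batch_mode=False each array is forwarded on its own.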
# Arguments pre-processing function
def _args_processing(self, args):
# Check if the input is a list/tuple that
# contains only np arrays or strings
if isinstance(args.img, (list, tuple)):
img_list = args.img
if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
raise AssertionError('Images must be strings or numpy arrays')
# Create a list of the images
if isinstance(args.img, str):
img_path = Path(args.img)
if img_path.is_dir():
img_list = [str(x) for x in img_path.glob('*')]
else:
img_list = [str(img_path)]
elif isinstance(args.img, np.ndarray):
img_list = [args.img]
# Read all image(s) in advance to reduce wasted time
# re-reading the images for visualization output
args.arrays = [mmcv.imread(x) for x in img_list]
# Create a list of filenames (used for output images and result files)
if isinstance(img_list[0], str):
args.filenames = [str(Path(x).stem) for x in img_list]
else:
args.filenames = [str(x) for x in range(len(img_list))]
# If given an output argument, create a list of output image filenames
num_res = len(img_list)
if args.output:
output_path = Path(args.output)
if output_path.is_dir():
args.output = [
str(output_path / f'out_{x}.png') for x in args.filenames
]
else:
args.output = [str(args.output)]
if args.batch_mode:
raise AssertionError('Output of multiple images inference'
' must be a directory')
else:
args.output = [None] * num_res
# If given an export argument, create a list of
# result filenames for each image
if args.export:
export_path = Path(args.export)
args.export = [
str(export_path / f'out_{x}.{args.export_format}')
for x in args.filenames
]
else:
args.export = [None] * num_res
return args
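# Expansion example (illustrative; glob order not guaranteed): for img='imgs/'
# containing a.jpg and b.jpg, output='out/' (an existing directory) and no
# export argument:
#   args.filenames == ['a', 'b']
#   args.output    == ['out/out_a.png', 'out/out_b.png']
#   args.export    == [None, None]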
# Create an inference pipeline with parsed arguments
def main():
args = parse_args()
ocr = MMOCR(**vars(args))
ocr.readtext(**vars(args))

if __name__ == '__main__':
main()
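# Command-line sketch (flag names are assumptions; parse_args, defined earlier
# in this file, is the authoritative reference):
#   python ocr.py demo/demo_text_ocr.jpg --det DB_r18 --recog CRNN \
#       --output demo/ --print-result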

View File

@ -16,6 +16,7 @@ def register_all_modules(init_default_scope: bool = True) -> None:
to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md
Defaults to True.
""" # noqa
import mmocr.apis # noqa: F401,F403
import mmocr.datasets # noqa: F401,F403
import mmocr.engine # noqa: F401,F403
import mmocr.evaluation # noqa: F401,F403

View File

@ -174,8 +174,6 @@ class TestPackKIEInputs(TestCase):
torch.int64)
self.assertIsInstance(data_sample.gt_instances.texts, list)
self.assertIn('img_path', data_sample)
transform = PackKIEInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
@ -191,7 +189,4 @@ class TestPackKIEInputs(TestCase):
self.assertEqual(results['inputs'].shape, torch.Size((0, 0, 0)))
def test_repr(self):
self.assertEqual(
repr(self.transform),
("PackKIEInputs(meta_keys=('img_path', 'ori_shape', "
"'img_shape', 'scale_factor'))"))
self.assertEqual(repr(self.transform), ('PackKIEInputs(meta_keys=())'))