mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Support batch visualization & dumping in Inferencer (#1722)
* [Enhancement] Support batch visualization & dumping in Inferencer * fix empty det output * Update mmocr/apis/inferencers/base_mmocr_inferencer.py Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com> --------- Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>pull/1737/head
parent
1127240108
commit
e9bf689f74
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
|
@ -9,6 +9,7 @@ from mmengine.dataset import Compose
|
|||
from mmengine.infer.infer import BaseInferencer, ModelType
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import InstanceData
|
||||
from rich.progress import track
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.utils import ConfigType
|
||||
|
@ -44,10 +45,10 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
forward_kwargs: set = set()
|
||||
visualize_kwargs: set = {
|
||||
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
|
||||
'img_out_dir'
|
||||
'save_vis'
|
||||
}
|
||||
postprocess_kwargs: set = {
|
||||
'print_result', 'pred_out_file', 'return_datasample'
|
||||
'print_result', 'return_datasample', 'save_pred'
|
||||
}
|
||||
loading_transforms: list = ['LoadImageFromFile', 'LoadImageFromNDArray']
|
||||
|
||||
|
@ -55,26 +56,69 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
model: Union[ModelType, str, None] = None,
|
||||
weights: Optional[str] = None,
|
||||
device: Optional[str] = None,
|
||||
scope: Optional[str] = 'mmocr') -> None:
|
||||
# A global counter tracking the number of images processed, for
|
||||
# naming of the output images
|
||||
self.num_visualized_imgs = 0
|
||||
scope: str = 'mmocr') -> None:
|
||||
# A global counter tracking the number of images given in the form
|
||||
# of ndarray, for naming the output images
|
||||
self.num_unnamed_imgs = 0
|
||||
init_default_scope(scope)
|
||||
super().__init__(
|
||||
model=model, weights=weights, device=device, scope=scope)
|
||||
|
||||
def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
    """Turn user inputs into collated batches ready for the model.

    The inputs are grouped by ``self._get_chunk_data`` and every group is
    then passed through ``self.collate_fn``.

    Args:
        inputs (InputsType): Inputs given by user.
        batch_size (int): Number of inputs per batch. Defaults to 1.
        **kwargs: Accepted for interface compatibility; not used here.

    Yields:
        Any: One collated batch per chunk of inputs.
    """
    collate = self.collate_fn
    for chunk in self._get_chunk_data(inputs, batch_size):
        yield collate(chunk)
|
||||
|
||||
def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
    """Run the pipeline on every input and group the outputs into chunks.

    Each input is passed through ``self.pipeline``. When the resulting
    ``data_samples`` reports no ``img_path``, a placeholder name built from
    the ``self.num_unnamed_imgs`` counter is written into its metainfo and
    the counter is incremented.

    Args:
        inputs (Iterable): An iterable dataset.
        chunk_size (int): Equivalent to batch size. Must be at least 1.

    Yields:
        list: A chunk of ``(original_input, pipeline_output)`` tuples.
        Every chunk holds ``chunk_size`` items except possibly the last.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1. As this is a
            generator, the error surfaces when iteration starts.
    """
    # Function-scope import so the method does not depend on file-level
    # imports that may not exist.
    from itertools import islice

    if chunk_size < 1:
        # An empty ``range(chunk_size)`` would otherwise produce an endless
        # stream of empty chunks.
        raise ValueError(
            f'chunk_size must be a positive integer, but got {chunk_size}')

    inputs_iter = iter(inputs)
    while True:
        chunk_data = []
        # ``islice`` stops on exhaustion by itself, so a StopIteration
        # leaking out of the pipeline is no longer mistaken for the end of
        # the inputs.
        for inputs_ in islice(inputs_iter, chunk_size):
            pipe_out = self.pipeline(inputs_)
            data_sample = pipe_out['data_samples']
            if data_sample.get('img_path') is None:
                data_sample.set_metainfo(
                    dict(img_path=f'{self.num_unnamed_imgs}.jpg'))
                self.num_unnamed_imgs += 1
            chunk_data.append((inputs_, pipe_out))

        if not chunk_data:
            return
        # Decide before yielding: the consumer may mutate the list.
        is_last_chunk = len(chunk_data) < chunk_size
        yield chunk_data
        if is_last_chunk:
            return
|
||||
|
||||
def __call__(self,
|
||||
inputs: InputsType,
|
||||
return_datasamples: bool = False,
|
||||
batch_size: int = 1,
|
||||
progress_bar: bool = True,
|
||||
return_vis: bool = False,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
draw_pred: bool = True,
|
||||
pred_score_thr: float = 0.3,
|
||||
img_out_dir: str = '',
|
||||
out_dir: str = 'results/',
|
||||
save_vis: bool = False,
|
||||
save_pred: bool = False,
|
||||
print_result: bool = False,
|
||||
pred_out_file: str = '',
|
||||
**kwargs) -> dict:
|
||||
"""Call the inferencer.
|
||||
|
||||
|
@ -85,6 +129,8 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
return_datasamples (bool): Whether to return results as
|
||||
:obj:`BaseDataElement`. Defaults to False.
|
||||
batch_size (int): Inference batch size. Defaults to 1.
|
||||
progress_bar (bool): Whether to show a progress bar. Defaults to
|
||||
True.
|
||||
return_vis (bool): Whether to return the visualization result.
|
||||
Defaults to False.
|
||||
show (bool): Whether to display the visualization results in a
|
||||
|
@ -94,8 +140,11 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
Defaults to True.
|
||||
pred_score_thr (float): Minimum score of bboxes to draw.
|
||||
Defaults to 0.3.
|
||||
img_out_dir (str): Output directory of visualization results.
|
||||
If left as empty, no file will be saved. Defaults to ''.
|
||||
out_dir (str): Output directory of results. Defaults to 'results/'.
|
||||
save_vis (bool): Whether to save the visualization results to
|
||||
"out_dir". Defaults to False.
|
||||
save_pred (bool): Whether to save the inference results to
|
||||
"out_dir". Defaults to False.
|
||||
print_result (bool): Whether to print the inference result w/o
|
||||
visualization to the console. Defaults to False.
|
||||
pred_out_file: File to save the inference results w/o
|
||||
|
@ -109,22 +158,53 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
and ``postprocess_kwargs``.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results.
|
||||
dict: Inference and visualization results, mapped from
|
||||
"predictions" and "visualization".
|
||||
"""
|
||||
return super().__call__(
|
||||
inputs,
|
||||
return_datasamples,
|
||||
batch_size,
|
||||
if (save_vis or save_pred) and not out_dir:
|
||||
raise ValueError('out_dir must be specified when save_vis or '
|
||||
'save_pred is True!')
|
||||
if out_dir:
|
||||
img_out_dir = osp.join(out_dir, 'vis')
|
||||
pred_out_dir = osp.join(out_dir, 'preds')
|
||||
else:
|
||||
img_out_dir, pred_out_dir = '', ''
|
||||
(
|
||||
preprocess_kwargs,
|
||||
forward_kwargs,
|
||||
visualize_kwargs,
|
||||
postprocess_kwargs,
|
||||
) = self._dispatch_kwargs(
|
||||
return_vis=return_vis,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
draw_pred=draw_pred,
|
||||
pred_score_thr=pred_score_thr,
|
||||
img_out_dir=img_out_dir,
|
||||
save_vis=save_vis,
|
||||
save_pred=save_pred,
|
||||
print_result=print_result,
|
||||
pred_out_file=pred_out_file,
|
||||
**kwargs)
|
||||
|
||||
ori_inputs = self._inputs_to_list(inputs)
|
||||
inputs = self.preprocess(
|
||||
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
|
||||
results = {'predictions': [], 'visualization': []}
|
||||
for ori_inputs, data in track(
|
||||
inputs, description='Inference', disable=not progress_bar):
|
||||
preds = self.forward(data, **forward_kwargs)
|
||||
visualization = self.visualize(
|
||||
ori_inputs, preds, img_out_dir=img_out_dir, **visualize_kwargs)
|
||||
batch_res = self.postprocess(
|
||||
preds,
|
||||
visualization,
|
||||
return_datasamples,
|
||||
pred_out_dir=pred_out_dir,
|
||||
**postprocess_kwargs)
|
||||
results['predictions'].extend(batch_res['predictions'])
|
||||
if batch_res['visualization'] is not None:
|
||||
results['visualization'].extend(batch_res['visualization'])
|
||||
return results
|
||||
|
||||
def _init_pipeline(self, cfg: ConfigType) -> Compose:
|
||||
"""Initialize the test pipeline."""
|
||||
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
|
||||
|
@ -170,6 +250,7 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
wait_time: int = 0,
|
||||
draw_pred: bool = True,
|
||||
pred_score_thr: float = 0.3,
|
||||
save_vis: bool = False,
|
||||
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
|
||||
"""Visualize predictions.
|
||||
|
||||
|
@ -185,6 +266,8 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
Defaults to True.
|
||||
pred_score_thr (float): Minimum score of bboxes to draw.
|
||||
Defaults to 0.3.
|
||||
save_vis (bool): Whether to save the visualization result. Defaults
|
||||
to False.
|
||||
img_out_dir (str): Output directory of visualization results.
|
||||
If left as empty, no file will be saved. Defaults to ''.
|
||||
|
||||
|
@ -192,8 +275,7 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
List[np.ndarray] or None: Returns visualization results only if
|
||||
applicable.
|
||||
"""
|
||||
if self.visualizer is None or (not show and img_out_dir == ''
|
||||
and not return_vis):
|
||||
if self.visualizer is None or not (show or save_vis or return_vis):
|
||||
return None
|
||||
|
||||
if getattr(self, 'visualizer') is None:
|
||||
|
@ -206,17 +288,19 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
if isinstance(single_input, str):
|
||||
img_bytes = mmengine.fileio.get(single_input)
|
||||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
||||
img_name = osp.basename(single_input)
|
||||
elif isinstance(single_input, np.ndarray):
|
||||
img = single_input.copy()[:, :, ::-1] # to RGB
|
||||
img_num = str(self.num_visualized_imgs).zfill(8)
|
||||
img_name = f'{img_num}.jpg'
|
||||
else:
|
||||
raise ValueError('Unsupported input type: '
|
||||
f'{type(single_input)}')
|
||||
img_name = osp.splitext(osp.basename(pred.img_path))[0]
|
||||
|
||||
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
|
||||
else None
|
||||
if save_vis and img_out_dir:
|
||||
out_file = osp.splitext(img_name)[0]
|
||||
out_file = f'{out_file}.jpg'
|
||||
out_file = osp.join(img_out_dir, out_file)
|
||||
else:
|
||||
out_file = None
|
||||
|
||||
visualization = self.visualizer.add_datasample(
|
||||
img_name,
|
||||
|
@ -230,7 +314,6 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
out_file=out_file,
|
||||
)
|
||||
results.append(visualization)
|
||||
self.num_visualized_imgs += 1
|
||||
|
||||
return results
|
||||
|
||||
|
@ -240,7 +323,8 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
visualization: Optional[List[np.ndarray]] = None,
|
||||
return_datasample: bool = False,
|
||||
print_result: bool = False,
|
||||
pred_out_file: str = '',
|
||||
save_pred: bool = False,
|
||||
pred_out_dir: str = '',
|
||||
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
|
||||
"""Process the predictions and visualization results from ``forward``
|
||||
and ``visualize``.
|
||||
|
@ -258,7 +342,9 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
inference results. If False, dict will be used.
|
||||
print_result (bool): Whether to print the inference result w/o
|
||||
visualization to the console. Defaults to False.
|
||||
pred_out_file: File to save the inference results w/o
|
||||
save_pred (bool): Whether to save the inference result. Defaults to
|
||||
False.
|
||||
pred_out_dir (str): Directory to save the inference results w/o
|
||||
visualization. If left as empty, no file will be saved.
|
||||
Defaults to ''.
|
||||
|
||||
|
@ -279,13 +365,16 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
results = []
|
||||
for pred in preds:
|
||||
result = self.pred2dict(pred)
|
||||
if save_pred and pred_out_dir:
|
||||
pred_name = osp.splitext(osp.basename(pred.img_path))[0]
|
||||
pred_name = f'{pred_name}.json'
|
||||
pred_out_file = osp.join(pred_out_dir, pred_name)
|
||||
mmengine.dump(result, pred_out_file)
|
||||
results.append(result)
|
||||
# Add img to the results after printing and dumping
|
||||
result_dict['predictions'] = results
|
||||
if print_result:
|
||||
print(result_dict)
|
||||
if pred_out_file != '':
|
||||
mmengine.dump(result_dict, pred_out_file)
|
||||
result_dict['visualization'] = visualization
|
||||
return result_dict
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.dataset import Compose, pseudo_collate
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
|
||||
from mmocr.registry import DATASETS
|
||||
|
@ -45,6 +45,7 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
super().__init__(
|
||||
model=model, weights=weights, device=device, scope=scope)
|
||||
self._load_metainfo_to_visualizer(weights, self.cfg)
|
||||
self.collate_fn = self.kie_collate
|
||||
|
||||
def _load_metainfo_to_visualizer(self, weights: Optional[str],
|
||||
cfg: ConfigType) -> None:
|
||||
|
@ -90,6 +91,21 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
return Compose(pipeline_cfg)
|
||||
return super()._init_pipeline(cfg)
|
||||
|
||||
@staticmethod
def kie_collate(data_batch: Sequence) -> Any:
    """A collate function designed for KIE.

    The first element of every item (the raw input) is a dict that must be
    kept as-is instead of having its contents batched, so only the
    remaining fields are passed through ``pseudo_collate``.

    Args:
        data_batch (Sequence): Items of a batch; each item is a sequence
            whose first element is the raw input.

    Returns:
        Any: The transposed batch. Its first entry is the tuple of raw
        inputs; every following entry is the ``pseudo_collate`` output of
        the corresponding field. Empty for an empty ``data_batch``.
    """
    transposed = list(zip(*data_batch))
    # Slice assignment leaves index 0 untouched and is a no-op when there
    # is nothing beyond the first field.
    transposed[1:] = [pseudo_collate(field) for field in transposed[1:]]
    return transposed
|
||||
|
||||
def _inputs_to_list(self, inputs: InputsType) -> list:
|
||||
"""Preprocess the inputs to a list.
|
||||
|
||||
|
@ -167,94 +183,9 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
else:
|
||||
atype = type(single_input['img'])
|
||||
raise ValueError(f'Unsupported input type: {atype}')
|
||||
|
||||
return processed_inputs
|
||||
|
||||
def __call__(self,
|
||||
inputs: InputsType,
|
||||
return_datasamples: bool = False,
|
||||
batch_size: int = 1,
|
||||
return_vis: bool = False,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
draw_pred: bool = True,
|
||||
pred_score_thr: float = 0.3,
|
||||
img_out_dir: str = '',
|
||||
print_result: bool = False,
|
||||
pred_out_file: str = '',
|
||||
**kwargs) -> dict:
|
||||
"""Call the inferencer.
|
||||
|
||||
The inputs for KIE Inferencer is special compared to other tasks.
|
||||
They can be a dict or list[dict], where each dictionary contains
|
||||
following keys:
|
||||
|
||||
- img (str or ndarray): Path to the image or the image itself. If KIE
|
||||
Inferencer is used in no-visual mode, this key is not required.
|
||||
Note: If it's an numpy array, it should be in BGR order.
|
||||
- img_shape (tuple(int, int)): Image shape in (H, W). In
|
||||
- instances (list[dict]): A list of instances.
|
||||
- bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
|
||||
- text (str): Annotation text.
|
||||
|
||||
Each ``instance`` looks like the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
# A nested list of 4 numbers representing the bounding box of
|
||||
# the instance, in (x1, y1, x2, y2) order.
|
||||
'bbox': np.array([[x1, y1, x2, y2], [x1, y1, x2, y2], ...],
|
||||
dtype=np.int32),
|
||||
|
||||
# List of texts.
|
||||
"texts": ['text1', 'text2', ...],
|
||||
}
|
||||
|
||||
Args:
|
||||
inputs (InputsType): Inputs for the inferencer.
|
||||
return_datasamples (bool): Whether to return results as
|
||||
:obj:`BaseDataElement`. Defaults to False.
|
||||
batch_size (int): Inference batch size. Defaults to 1.
|
||||
return_vis (bool): Whether to return the visualization result.
|
||||
Defaults to False.
|
||||
show (bool): Whether to display the visualization results in a
|
||||
popup window. Defaults to False.
|
||||
wait_time (float): The interval of show (s). Defaults to 0.
|
||||
draw_pred (bool): Whether to draw predicted bounding boxes.
|
||||
Defaults to True.
|
||||
pred_score_thr (float): Minimum score of bboxes to draw.
|
||||
Defaults to 0.3.
|
||||
img_out_dir (str): Output directory of visualization results.
|
||||
If left as empty, no file will be saved. Defaults to ''.
|
||||
print_result (bool): Whether to print the inference result w/o
|
||||
visualization to the console. Defaults to False.
|
||||
pred_out_file: File to save the inference results w/o
|
||||
visualization. If left as empty, no file will be saved.
|
||||
Defaults to ''.
|
||||
|
||||
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
|
||||
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
||||
Each key in kwargs should be in the corresponding set of
|
||||
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
|
||||
and ``postprocess_kwargs``.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results.
|
||||
"""
|
||||
return super().__call__(
|
||||
inputs,
|
||||
return_datasamples,
|
||||
batch_size,
|
||||
return_vis=return_vis,
|
||||
show=show,
|
||||
wait_time=wait_time,
|
||||
draw_pred=draw_pred,
|
||||
pred_score_thr=pred_score_thr,
|
||||
img_out_dir=img_out_dir,
|
||||
print_result=print_result,
|
||||
pred_out_file=pred_out_file,
|
||||
**kwargs)
|
||||
|
||||
def visualize(self,
|
||||
inputs: InputsType,
|
||||
preds: PredType,
|
||||
|
@ -263,7 +194,8 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
wait_time: int = 0,
|
||||
draw_pred: bool = True,
|
||||
pred_score_thr: float = 0.3,
|
||||
img_out_dir: str = '') -> List[np.ndarray]:
|
||||
save_vis: bool = False,
|
||||
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
|
||||
"""Visualize predictions.
|
||||
|
||||
Args:
|
||||
|
@ -278,10 +210,16 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
Defaults to True.
|
||||
pred_score_thr (float): Minimum score of bboxes to draw.
|
||||
Defaults to 0.3.
|
||||
img_out_dir (str): Output directory of images. Defaults to ''.
|
||||
save_vis (bool): Whether to save the visualization result. Defaults
|
||||
to False.
|
||||
img_out_dir (str): Output directory of visualization results.
|
||||
If left as empty, no file will be saved. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
List[np.ndarray] or None: Returns visualization results only if
|
||||
applicable.
|
||||
"""
|
||||
if self.visualizer is None or (not show and img_out_dir == ''
|
||||
and not return_vis):
|
||||
if self.visualizer is None or not (show or save_vis or return_vis):
|
||||
return None
|
||||
|
||||
if getattr(self, 'visualizer') is None:
|
||||
|
@ -291,24 +229,26 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
results = []
|
||||
|
||||
for single_input, pred in zip(inputs, preds):
|
||||
img_num = str(self.num_visualized_imgs).zfill(8)
|
||||
assert 'img' in single_input or 'img_shape' in single_input
|
||||
if 'img' in single_input:
|
||||
if isinstance(single_input['img'], str):
|
||||
img = mmcv.imread(single_input['img'], channel_order='rgb')
|
||||
img_name = osp.basename(single_input['img'])
|
||||
img_bytes = mmengine.fileio.get(single_input['img'])
|
||||
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
|
||||
elif isinstance(single_input['img'], np.ndarray):
|
||||
img = single_input['img'].copy()[:, :, ::-1] # To RGB
|
||||
img_name = f'{img_num}.jpg'
|
||||
elif 'img_shape' in single_input:
|
||||
img = np.zeros(single_input['img_shape'], dtype=np.uint8)
|
||||
img_name = f'{img_num}.jpg'
|
||||
else:
|
||||
raise ValueError('Input does not contain either "img" or '
|
||||
'"img_shape"')
|
||||
img_name = osp.splitext(osp.basename(pred.img_path))[0]
|
||||
|
||||
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
|
||||
else None
|
||||
if save_vis and img_out_dir:
|
||||
out_file = osp.splitext(img_name)[0]
|
||||
out_file = f'{out_file}.jpg'
|
||||
out_file = osp.join(img_out_dir, out_file)
|
||||
else:
|
||||
out_file = None
|
||||
|
||||
visualization = self.visualizer.add_datasample(
|
||||
img_name,
|
||||
|
@ -322,7 +262,6 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
out_file=out_file,
|
||||
)
|
||||
results.append(visualization)
|
||||
self.num_visualized_imgs += 1
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmengine.fileio import (get_file_backend, isdir, join_path,
|
||||
list_dir_or_file)
|
||||
from rich.progress import track
|
||||
|
||||
from mmocr.registry import VISUALIZERS
|
||||
from mmocr.structures.textdet_data_sample import TextDetDataSample
|
||||
from mmocr.structures import TextSpottingDataSample
|
||||
from mmocr.utils import ConfigType, bbox2poly, crop_img, poly2bbox
|
||||
from .base_mmocr_inferencer import (BaseMMOCRInferencer, InputsType, PredType,
|
||||
ResType)
|
||||
|
@ -65,7 +64,6 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
'provided.')
|
||||
|
||||
self.visualizer = None
|
||||
self.num_visualized_imgs = 0
|
||||
|
||||
if det is not None:
|
||||
self.textdet_inferencer = TextDetInferencer(
|
||||
|
@ -93,43 +91,19 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
self.kie_inferencer = KIEInferencer(kie, kie_weights, device)
|
||||
self.mode = 'det_rec_kie'
|
||||
|
||||
def _inputs_to_list(self, inputs: InputsType) -> list:
|
||||
"""Preprocess the inputs to a list.
|
||||
|
||||
Preprocess inputs to a list according to its type:
|
||||
|
||||
- list or tuple: return inputs
|
||||
- str:
|
||||
- Directory path: return all files in the directory
|
||||
- normal string: return a list containing the string
|
||||
|
||||
Args:
|
||||
inputs (InputsType): Inputs for the inferencer.
|
||||
|
||||
Returns:
|
||||
list: List of input for the :meth:`preprocess`.
|
||||
"""
|
||||
inputs = copy.deepcopy(inputs)
|
||||
if isinstance(inputs, str):
|
||||
backend = get_file_backend(inputs)
|
||||
if hasattr(backend, 'isdir') and isdir(inputs):
|
||||
# Backends like HttpsBackend do not implement `isdir`, so only
|
||||
# those backends that implement `isdir` could accept the inputs
|
||||
# as a directory
|
||||
filename_list = list_dir_or_file(inputs, list_dir=False)
|
||||
inputs = [
|
||||
join_path(inputs, filename) for filename in filename_list
|
||||
]
|
||||
|
||||
if not isinstance(inputs, (list, tuple)):
|
||||
inputs = [inputs]
|
||||
|
||||
for i in range(len(inputs)):
|
||||
if not isinstance(inputs[i], np.ndarray):
|
||||
img_bytes = mmengine.fileio.get(inputs[i])
|
||||
inputs[i] = mmcv.imfrombytes(img_bytes)
|
||||
|
||||
return list(inputs)
|
||||
def _inputs2ndarrray(self, inputs: List[InputsType]) -> List[np.ndarray]:
    """Preprocess the inputs to a list of numpy arrays.

    Args:
        inputs (List[InputsType]): Items that are either arrays or strings.
            Arrays are passed through unchanged; strings are fetched with
            ``mmengine.fileio.get`` and decoded with ``mmcv.imfrombytes``.

    Returns:
        List[np.ndarray]: One array per input, in the original order.

    Raises:
        NotImplementedError: If an item is neither an array nor a string.
    """
    new_inputs = []
    for item in inputs:
        if isinstance(item, np.ndarray):
            new_inputs.append(item)
        elif isinstance(item, str):
            img_bytes = mmengine.fileio.get(item)
            new_inputs.append(mmcv.imfrombytes(img_bytes))
        else:
            # The trailing space keeps the two literals from fusing into
            # "notsupported".
            raise NotImplementedError(f'The input type {type(item)} is not '
                                      'supported yet.')
    return new_inputs
|
||||
|
||||
def forward(self, inputs: InputsType, batch_size: int,
|
||||
**forward_kwargs) -> PredType:
|
||||
|
@ -144,6 +118,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
"kie"..
|
||||
"""
|
||||
result = {}
|
||||
forward_kwargs['progress_bar'] = False
|
||||
if self.mode == 'rec':
|
||||
# The extra list wrapper here is for the ease of postprocessing
|
||||
self.rec_inputs = inputs
|
||||
|
@ -153,15 +128,16 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
batch_size=batch_size,
|
||||
**forward_kwargs)['predictions']
|
||||
result['rec'] = [[p] for p in predictions]
|
||||
elif self.mode.startswith('det'):
|
||||
elif self.mode.startswith('det'): # 'det'/'det_rec'/'det_rec_kie'
|
||||
result['det'] = self.textdet_inferencer(
|
||||
inputs,
|
||||
return_datasamples=True,
|
||||
batch_size=batch_size,
|
||||
**forward_kwargs)['predictions']
|
||||
if self.mode.startswith('det_rec'):
|
||||
if self.mode.startswith('det_rec'): # 'det_rec'/'det_rec_kie'
|
||||
result['rec'] = []
|
||||
for img, det_data_sample in zip(inputs, result['det']):
|
||||
for img, det_data_sample in zip(
|
||||
self._inputs2ndarrray(inputs), result['det']):
|
||||
det_pred = det_data_sample.pred_instances
|
||||
self.rec_inputs = []
|
||||
for polygon in det_pred['polygons']:
|
||||
|
@ -177,6 +153,10 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
**forward_kwargs)['predictions'])
|
||||
if self.mode == 'det_rec_kie':
|
||||
self.kie_inputs = []
|
||||
# TODO: when the det output is empty, kie will fail
|
||||
# as no gt-instances can be provided. It's a known
|
||||
# issue but cannot be solved elegantly since we support
|
||||
# batch inference.
|
||||
for img, det_data_sample, rec_data_samples in zip(
|
||||
inputs, result['det'], result['rec']):
|
||||
det_pred = det_data_sample.pred_instances
|
||||
|
@ -197,7 +177,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
return result
|
||||
|
||||
def visualize(self, inputs: InputsType, preds: PredType,
|
||||
**kwargs) -> List[np.ndarray]:
|
||||
**kwargs) -> Union[List[np.ndarray], None]:
|
||||
"""Visualize predictions.
|
||||
|
||||
Args:
|
||||
|
@ -210,7 +190,14 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
Defaults to True.
|
||||
pred_score_thr (float): Minimum score of bboxes to draw.
|
||||
Defaults to 0.3.
|
||||
img_out_dir (str): Output directory of images. Defaults to ''.
|
||||
save_vis (bool): Whether to save the visualization result. Defaults
|
||||
to False.
|
||||
img_out_dir (str): Output directory of visualization results.
|
||||
If left as empty, no file will be saved. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
List[np.ndarray] or None: Returns visualization results only if
|
||||
applicable.
|
||||
"""
|
||||
|
||||
if 'kie' in self.mode:
|
||||
|
@ -232,6 +219,9 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
self,
|
||||
inputs: InputsType,
|
||||
batch_size: int = 1,
|
||||
out_dir: str = 'results/',
|
||||
save_vis: bool = False,
|
||||
save_pred: bool = False,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Call the inferencer.
|
||||
|
@ -239,9 +229,12 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
Args:
|
||||
inputs (InputsType): Inputs for the inferencer. It can be a path
|
||||
to image / image directory, or an array, or a list of these.
|
||||
return_datasamples (bool): Whether to return results as
|
||||
:obj:`BaseDataElement`. Defaults to False.
|
||||
batch_size (int): Batch size. Defaults to 1.
|
||||
out_dir (str): Output directory of results. Defaults to 'results/'.
|
||||
save_vis (bool): Whether to save the visualization results to
|
||||
"out_dir". Defaults to False.
|
||||
save_pred (bool): Whether to save the inference results to
|
||||
"out_dir". Defaults to False.
|
||||
**kwargs: Key words arguments passed to :meth:`preprocess`,
|
||||
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
|
||||
Each key in kwargs should be in the corresponding set of
|
||||
|
@ -249,47 +242,75 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
and ``postprocess_kwargs``.
|
||||
|
||||
Returns:
|
||||
dict: Inference and visualization results.
|
||||
dict: Inference and visualization results, mapped from
|
||||
"predictions" and "visualization".
|
||||
"""
|
||||
if (save_vis or save_pred) and not out_dir:
|
||||
raise ValueError('out_dir must be specified when save_vis or '
|
||||
'save_pred is True!')
|
||||
if out_dir:
|
||||
img_out_dir = osp.join(out_dir, 'vis')
|
||||
pred_out_dir = osp.join(out_dir, 'preds')
|
||||
else:
|
||||
img_out_dir, pred_out_dir = '', ''
|
||||
|
||||
(
|
||||
preprocess_kwargs,
|
||||
forward_kwargs,
|
||||
visualize_kwargs,
|
||||
postprocess_kwargs,
|
||||
) = self._dispatch_kwargs(**kwargs)
|
||||
) = self._dispatch_kwargs(
|
||||
save_vis=save_vis, save_pred=save_pred, **kwargs)
|
||||
|
||||
ori_inputs = self._inputs_to_list(inputs)
|
||||
|
||||
preds = self.forward(ori_inputs, batch_size, **forward_kwargs)
|
||||
|
||||
visualization = self.visualize(
|
||||
ori_inputs, preds,
|
||||
**visualize_kwargs) # type: ignore # noqa: E501
|
||||
results = self.postprocess(preds, visualization, **postprocess_kwargs)
|
||||
chunked_inputs = super(BaseMMOCRInferencer,
|
||||
self)._get_chunk_data(ori_inputs, batch_size)
|
||||
results = {'predictions': [], 'visualization': []}
|
||||
for ori_input in track(chunked_inputs, description='Inference'):
|
||||
preds = self.forward(ori_input, batch_size, **forward_kwargs)
|
||||
visualization = self.visualize(
|
||||
ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
|
||||
batch_res = self.postprocess(
|
||||
preds,
|
||||
visualization,
|
||||
pred_out_dir=pred_out_dir,
|
||||
**postprocess_kwargs)
|
||||
results['predictions'].extend(batch_res['predictions'])
|
||||
if batch_res['visualization'] is not None:
|
||||
results['visualization'].extend(batch_res['visualization'])
|
||||
return results
|
||||
|
||||
def postprocess(self,
|
||||
preds: PredType,
|
||||
visualization: Optional[List[np.ndarray]] = None,
|
||||
print_result: bool = False,
|
||||
pred_out_file: str = ''
|
||||
save_pred: bool = False,
|
||||
pred_out_dir: str = ''
|
||||
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
|
||||
"""Postprocess predictions.
|
||||
"""Process the predictions and visualization results from ``forward``
|
||||
and ``visualize``.
|
||||
|
||||
This method should be responsible for the following tasks:
|
||||
|
||||
1. Convert datasamples into a json-serializable dict if needed.
|
||||
2. Pack the predictions and visualization results and return them.
|
||||
3. Dump or log the predictions.
|
||||
|
||||
Args:
|
||||
preds (Dict): Predictions of the model.
|
||||
preds (PredType): Predictions of the model.
|
||||
visualization (Optional[np.ndarray]): Visualized predictions.
|
||||
print_result (bool): Whether to print the result.
|
||||
Defaults to False.
|
||||
pred_out_file (str): Output file name to store predictions
|
||||
without images. Supported file formats are “json”, “yaml/yml”
|
||||
and “pickle/pkl”. Defaults to ''.
|
||||
save_pred (bool): Whether to save the inference result. Defaults to
|
||||
False.
|
||||
pred_out_dir (str): Directory to save the inference results w/o
|
||||
visualization. If left as empty, no file will be saved.
|
||||
Defaults to ''.
|
||||
|
||||
Returns:
|
||||
Dict or List[Dict]: Each dict contains the inference result of
|
||||
each image. Possible keys are "det_polygons", "det_scores",
|
||||
"rec_texts", "rec_scores", "kie_labels", "kie_scores",
|
||||
"kie_edge_labels" and "kie_edge_scores".
|
||||
Dict: Inference and visualization results, mapped from
|
||||
"predictions" and "visualization".
|
||||
"""
|
||||
|
||||
result_dict = {}
|
||||
|
@ -320,22 +341,28 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
kie_edge_scores=kie_dict_res['edge_scores'],
|
||||
kie_edge_labels=kie_dict_res['edge_labels'])
|
||||
|
||||
if save_pred and pred_out_dir:
|
||||
pred_key = 'det' if 'det' in self.mode else 'rec'
|
||||
for pred, pred_result in zip(preds[pred_key], pred_results):
|
||||
img_path = (
|
||||
pred.img_path if pred_key == 'det' else pred[0].img_path)
|
||||
pred_name = osp.splitext(osp.basename(img_path))[0]
|
||||
pred_name = f'{pred_name}.json'
|
||||
pred_out_file = osp.join(pred_out_dir, pred_name)
|
||||
mmengine.dump(pred_result, pred_out_file)
|
||||
|
||||
result_dict['predictions'] = pred_results
|
||||
if print_result:
|
||||
print(result_dict)
|
||||
if pred_out_file != '':
|
||||
mmengine.dump(result_dict, pred_out_file)
|
||||
result_dict['visualization'] = visualization
|
||||
return result_dict
|
||||
|
||||
def _pack_e2e_datasamples(self, preds: Dict) -> List[TextDetDataSample]:
|
||||
def _pack_e2e_datasamples(self,
|
||||
preds: Dict) -> List[TextSpottingDataSample]:
|
||||
"""Pack text detection and recognition results into a list of
|
||||
TextDetDataSample.
|
||||
|
||||
Note that it is a temporary solution since the TextSpottingDataSample
|
||||
is not ready.
|
||||
"""
|
||||
TextSpottingDataSample."""
|
||||
results = []
|
||||
|
||||
for det_data_sample, rec_data_samples in zip(preds['det'],
|
||||
preds['rec']):
|
||||
texts = []
|
||||
|
|
17
mmocr/ocr.py
17
mmocr/ocr.py
|
@ -9,10 +9,10 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
'inputs', type=str, help='Input image file or folder path.')
|
||||
parser.add_argument(
|
||||
'--img-out-dir',
|
||||
'--out-dir',
|
||||
type=str,
|
||||
default='',
|
||||
help='Output directory of images.')
|
||||
default='results/',
|
||||
help='Output directory of results.')
|
||||
parser.add_argument(
|
||||
'--det',
|
||||
type=str,
|
||||
|
@ -69,10 +69,13 @@ def parse_args():
|
|||
action='store_true',
|
||||
help='Whether to print the results.')
|
||||
parser.add_argument(
|
||||
'--pred-out-file',
|
||||
type=str,
|
||||
default='',
|
||||
help='File to save the inference results.')
|
||||
'--save_pred',
|
||||
action='store_true',
|
||||
help='Save the inference results to out_dir.')
|
||||
parser.add_argument(
|
||||
'--save_vis',
|
||||
action='store_true',
|
||||
help='Save the visualization results to out_dir.')
|
||||
|
||||
call_args = vars(parser.parse_args())
|
||||
|
||||
|
|
|
@ -85,6 +85,8 @@ class BaseLocalVisualizer(Visualizer):
|
|||
font_families (Union[str, List[str]]): The font families of labels.
|
||||
Defaults to 'sans-serif'.
|
||||
"""
|
||||
if not labels and not bboxes:
|
||||
return image
|
||||
if colors is not None and isinstance(colors, (list, tuple)):
|
||||
size = math.ceil(len(labels) / len(colors))
|
||||
colors = (colors * size)[:len(labels)]
|
||||
|
|
|
@ -44,18 +44,19 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
|
|||
img_shape = image.shape[:2]
|
||||
empty_shape = (img_shape[0], img_shape[1], 3)
|
||||
text_image = np.full(empty_shape, 255, dtype=np.uint8)
|
||||
text_image = self.get_labels_image(
|
||||
text_image,
|
||||
labels=texts,
|
||||
bboxes=bboxes,
|
||||
font_families=self.font_families)
|
||||
if texts:
|
||||
text_image = self.get_labels_image(
|
||||
text_image,
|
||||
labels=texts,
|
||||
bboxes=bboxes,
|
||||
font_families=self.font_families)
|
||||
if polygons:
|
||||
polygons = [polygon.reshape(-1, 2) for polygon in polygons]
|
||||
image = self.get_polygons_image(
|
||||
image, polygons, filling=True, colors=self.PALETTE)
|
||||
text_image = self.get_polygons_image(
|
||||
text_image, polygons, colors=self.PALETTE)
|
||||
else:
|
||||
elif len(bboxes) > 0:
|
||||
image = self.get_bboxes_image(
|
||||
image, bboxes, filling=True, colors=self.PALETTE)
|
||||
text_image = self.get_bboxes_image(
|
||||
|
@ -103,27 +104,28 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
|
|||
"""
|
||||
cat_images = []
|
||||
|
||||
if draw_gt:
|
||||
gt_bboxes = data_sample.gt_instances.get('bboxes', None)
|
||||
gt_texts = data_sample.gt_instances.texts
|
||||
gt_polygons = data_sample.gt_instances.get('polygons', None)
|
||||
gt_img_data = self._draw_instances(image, gt_bboxes, gt_polygons,
|
||||
gt_texts)
|
||||
cat_images.append(gt_img_data)
|
||||
if data_sample is not None:
|
||||
if draw_gt and 'gt_instances' in data_sample:
|
||||
gt_bboxes = data_sample.gt_instances.get('bboxes', None)
|
||||
gt_texts = data_sample.gt_instances.texts
|
||||
gt_polygons = data_sample.gt_instances.get('polygons', None)
|
||||
gt_img_data = self._draw_instances(image, gt_bboxes,
|
||||
gt_polygons, gt_texts)
|
||||
cat_images.append(gt_img_data)
|
||||
|
||||
if draw_pred:
|
||||
pred_instances = data_sample.pred_instances
|
||||
pred_instances = pred_instances[
|
||||
pred_instances.scores > pred_score_thr].cpu().numpy()
|
||||
pred_bboxes = pred_instances.get('bboxes', None)
|
||||
pred_texts = pred_instances.texts
|
||||
pred_polygons = pred_instances.get('polygons', None)
|
||||
if pred_bboxes is None:
|
||||
pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
|
||||
pred_bboxes = np.array(pred_bboxes)
|
||||
pred_img_data = self._draw_instances(image, pred_bboxes,
|
||||
pred_polygons, pred_texts)
|
||||
cat_images.append(pred_img_data)
|
||||
if draw_pred and 'pred_instances' in data_sample:
|
||||
pred_instances = data_sample.pred_instances
|
||||
pred_instances = pred_instances[
|
||||
pred_instances.scores > pred_score_thr].cpu().numpy()
|
||||
pred_bboxes = pred_instances.get('bboxes', None)
|
||||
pred_texts = pred_instances.texts
|
||||
pred_polygons = pred_instances.get('polygons', None)
|
||||
if pred_bboxes is None:
|
||||
pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
|
||||
pred_bboxes = np.array(pred_bboxes)
|
||||
pred_img_data = self._draw_instances(image, pred_bboxes,
|
||||
pred_polygons, pred_texts)
|
||||
cat_images.append(pred_img_data)
|
||||
|
||||
cat_images = self._cat_image(cat_images, axis=0)
|
||||
if cat_images is None:
|
||||
|
|
|
@ -75,12 +75,15 @@ class TestKIEInferencer(TestCase):
|
|||
|
||||
def assert_predictions_equal(self, preds1, preds2):
|
||||
for pred1, pred2 in zip(preds1, preds2):
|
||||
self.assertTrue(np.allclose(pred1['labels'], pred2['labels'], 0.1))
|
||||
self.assertTrue(
|
||||
np.allclose(pred1['edge_scores'], pred2['edge_scores'], 0.1))
|
||||
self.assertTrue(
|
||||
np.allclose(pred1['edge_labels'], pred2['edge_labels'], 0.1))
|
||||
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
||||
self.assert_prediction_equal(pred1, pred2)
|
||||
|
||||
def assert_prediction_equal(self, pred1, pred2):
|
||||
self.assertTrue(np.allclose(pred1['labels'], pred2['labels'], 0.1))
|
||||
self.assertTrue(
|
||||
np.allclose(pred1['edge_scores'], pred2['edge_scores'], 0.1))
|
||||
self.assertTrue(
|
||||
np.allclose(pred1['edge_labels'], pred2['edge_labels'], 0.1))
|
||||
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
||||
|
||||
def test_call(self):
|
||||
# no visual, single input
|
||||
|
@ -130,9 +133,9 @@ class TestKIEInferencer(TestCase):
|
|||
|
||||
# save visualization under out_dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.inferencer(self.data_img_str, img_out_dir=tmp_dir)
|
||||
for img_dir in ['00000000.jpg', '00000001.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
|
||||
self.inferencer(self.data_img_str, out_dir=tmp_dir, save_vis=True)
|
||||
for img_dir in ['1.jpg', '2.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
|
||||
def test_postprocess(self):
|
||||
# return_datasample
|
||||
|
@ -141,14 +144,19 @@ class TestKIEInferencer(TestCase):
|
|||
|
||||
# dump predictions
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
|
||||
res = self.inferencer(
|
||||
self.data_img_ndarray,
|
||||
print_result=True,
|
||||
pred_out_file=pred_out_file)
|
||||
dumped_res = mmengine.load(pred_out_file)
|
||||
self.assert_predictions_equal(res['predictions'],
|
||||
dumped_res['predictions'])
|
||||
out_dir=tmp_dir,
|
||||
save_pred=True)
|
||||
file_names = [
|
||||
f'{self.inferencer.num_unnamed_imgs - i}.json'
|
||||
for i in range(len(self.data_img_ndarray), 0, -1)
|
||||
]
|
||||
for pred, file_name in zip(res['predictions'], file_names):
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', file_name))
|
||||
self.assert_prediction_equal(dumped_res, pred)
|
||||
|
||||
@mock.patch('mmocr.apis.inferencers.kie_inferencer._load_checkpoint')
|
||||
def test_load_metainfo_to_visualizer(self, mock_load):
|
||||
|
|
|
@ -95,6 +95,18 @@ class TestMMOCRInferencer(TestCase):
|
|||
np.allclose(res_img_ndarray['visualization'][0],
|
||||
res_img_path['visualization'][0]))
|
||||
|
||||
# test save_vis and save_pred
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
res = inferencer(
|
||||
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
|
||||
for img_dir in ['img_1.jpg', 'img_2.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
for i, pred_dir in enumerate(['img_1.json', 'img_2.json']):
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', pred_dir))
|
||||
self.assert_predictions_equal(res['predictions'][i],
|
||||
dumped_res)
|
||||
|
||||
@mock.patch('mmengine.infer.infer._load_checkpoint')
|
||||
def test_rec(self, mock_load):
|
||||
mock_load.side_effect = lambda *x, **y: None
|
||||
|
@ -129,6 +141,18 @@ class TestMMOCRInferencer(TestCase):
|
|||
np.allclose(res_img_ndarray['visualization'][0],
|
||||
res_img_path['visualization'][0]))
|
||||
|
||||
# test save_vis and save_pred
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
res = inferencer(
|
||||
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
|
||||
for img_dir in ['1036169.jpg', '1058891.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
for i, pred_dir in enumerate(['1036169.json', '1058891.json']):
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', pred_dir))
|
||||
self.assert_predictions_equal(res['predictions'][i],
|
||||
dumped_res)
|
||||
|
||||
@mock.patch('mmengine.infer.infer._load_checkpoint')
|
||||
def test_det_rec(self, mock_load):
|
||||
mock_load.side_effect = lambda *x, **y: None
|
||||
|
@ -166,6 +190,21 @@ class TestMMOCRInferencer(TestCase):
|
|||
np.allclose(res_img_ndarray['visualization'][0],
|
||||
res_img_path['visualization'][0]))
|
||||
|
||||
# test save_vis and save_pred
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
res = inferencer(
|
||||
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
|
||||
for img_dir in ['img_1.jpg', 'img_2.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
for i, pred_dir in enumerate(['img_1.json', 'img_2.json']):
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', pred_dir))
|
||||
self.assert_predictions_equal(res['predictions'][i],
|
||||
dumped_res)
|
||||
|
||||
# corner case: when the det model cannot detect any texts
|
||||
inferencer(np.zeros((100, 100, 3)), return_vis=True)
|
||||
|
||||
@mock.patch('mmengine.infer.infer._load_checkpoint')
|
||||
def test_dec_rec_kie(self, mock_load):
|
||||
mock_load.side_effect = lambda *x, **y: None
|
||||
|
@ -205,18 +244,14 @@ class TestMMOCRInferencer(TestCase):
|
|||
np.allclose(res_img_ndarray['visualization'][0],
|
||||
res_img_path['visualization'][0]))
|
||||
|
||||
# test visualization
|
||||
# img_out_dir
|
||||
# test save_vis and save_pred
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
inferencer(img_paths, img_out_dir=tmp_dir)
|
||||
for img_dir in ['00000006.jpg', '00000007.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
|
||||
|
||||
# pred_out_file
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
|
||||
res = inferencer(
|
||||
img_path, print_result=True, pred_out_file=pred_out_file)
|
||||
dumped_res = mmengine.load(pred_out_file)
|
||||
self.assert_predictions_equal(res['predictions'],
|
||||
dumped_res['predictions'])
|
||||
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
|
||||
for img_dir in ['1.jpg', '2.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
for i, pred_dir in enumerate(['1.json', '2.json']):
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', pred_dir))
|
||||
self.assert_predictions_equal(res['predictions'][i],
|
||||
dumped_res)
|
||||
|
|
|
@ -41,9 +41,11 @@ class TestTextDetinferencer(TestCase):
|
|||
|
||||
def assert_predictions_equal(self, preds1, preds2):
|
||||
for pred1, pred2 in zip(preds1, preds2):
|
||||
self.assertTrue(
|
||||
np.allclose(pred1['polygons'], pred2['polygons'], 0.1))
|
||||
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
||||
self.assert_prediction_equal(pred1, pred2)
|
||||
|
||||
def assert_prediction_equal(self, pred1, pred2):
|
||||
self.assertTrue(np.allclose(pred1['polygons'], pred2['polygons'], 0.1))
|
||||
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
||||
|
||||
def test_call(self):
|
||||
# single img
|
||||
|
@ -91,9 +93,9 @@ class TestTextDetinferencer(TestCase):
|
|||
|
||||
# save visualization under out_dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.inferencer(img_paths, img_out_dir=tmp_dir)
|
||||
self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True)
|
||||
for img_dir in ['img_1.jpg', 'img_2.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
|
||||
def test_postprocess(self):
|
||||
# return_datasample
|
||||
|
@ -101,14 +103,13 @@ class TestTextDetinferencer(TestCase):
|
|||
res = self.inferencer(img_path, return_datasamples=True)
|
||||
self.assertTrue(is_type_list(res['predictions'], TextDetDataSample))
|
||||
|
||||
# pred_out_file
|
||||
# dump predictions
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
|
||||
res = self.inferencer(
|
||||
img_path, print_result=True, pred_out_file=pred_out_file)
|
||||
dumped_res = mmengine.load(pred_out_file)
|
||||
self.assert_predictions_equal(res['predictions'],
|
||||
dumped_res['predictions'])
|
||||
img_path, print_result=True, out_dir=tmp_dir, save_pred=True)
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', 'img_1.json'))
|
||||
self.assert_prediction_equal(res['predictions'][0], dumped_res)
|
||||
|
||||
def test_pred2dict(self):
|
||||
data_sample = TextDetDataSample()
|
||||
|
|
|
@ -40,8 +40,11 @@ class TestTextRecinferencer(TestCase):
|
|||
|
||||
def assert_predictions_equal(self, preds1, preds2):
|
||||
for pred1, pred2 in zip(preds1, preds2):
|
||||
self.assertEqual(pred1['text'], pred2['text'])
|
||||
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
||||
self.assert_prediction_equal(pred1, pred2)
|
||||
|
||||
def assert_prediction_equal(self, pred1, pred2):
|
||||
self.assertEqual(pred1['text'], pred2['text'])
|
||||
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
||||
|
||||
def test_call(self):
|
||||
# single img
|
||||
|
@ -86,9 +89,9 @@ class TestTextRecinferencer(TestCase):
|
|||
|
||||
# save visualization under out_dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.inferencer(img_paths, img_out_dir=tmp_dir)
|
||||
self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True)
|
||||
for img_dir in ['1036169.jpg', '1058891.jpg']:
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
|
||||
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
||||
|
||||
def test_postprocess(self):
|
||||
# return_datasample
|
||||
|
@ -98,9 +101,8 @@ class TestTextRecinferencer(TestCase):
|
|||
|
||||
# dump predictions
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
|
||||
res = self.inferencer(
|
||||
img_path, print_result=True, pred_out_file=pred_out_file)
|
||||
dumped_res = mmengine.load(pred_out_file)
|
||||
self.assert_predictions_equal(res['predictions'],
|
||||
dumped_res['predictions'])
|
||||
img_path, print_result=True, out_dir=tmp_dir, save_pred=True)
|
||||
dumped_res = mmengine.load(
|
||||
osp.join(tmp_dir, 'preds', '1036169.json'))
|
||||
self.assert_prediction_equal(res['predictions'][0], dumped_res)
|
||||
|
|
Loading…
Reference in New Issue