[Enhancement] Support batch visualization & dumping in Inferencer (#1722)

* [Enhancement] Support batch visualization & dumping in Inferencer

* fix empty det output

* Update mmocr/apis/inferencers/base_mmocr_inferencer.py

Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>

---------

Co-authored-by: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Tong Gao 2023-02-17 12:40:09 +08:00 committed by GitHub
parent 1127240108
commit e9bf689f74
10 changed files with 393 additions and 285 deletions
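For context, a minimal usage sketch of the batched interface added by this commit (the model aliases, image paths and automatic weight resolution are assumptions for illustration, not part of the diff):

from mmocr.apis import MMOCRInferencer

# Hypothetical example: detect + recognize two images in one batch and dump the
# per-image outputs under out_dir ('results/vis' for drawings, 'results/preds' for JSON).
infer = MMOCRInferencer(det='DBNet', rec='CRNN')
results = infer(
    ['demo/img_1.jpg', 'demo/img_2.jpg'],
    batch_size=2,
    out_dir='results/',
    save_vis=True,
    save_pred=True)
print(results['predictions'][0])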


@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import mmcv
import mmengine
@ -9,6 +9,7 @@ from mmengine.dataset import Compose
from mmengine.infer.infer import BaseInferencer, ModelType
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
from rich.progress import track
from torch import Tensor
from mmocr.utils import ConfigType
@ -44,10 +45,10 @@ class BaseMMOCRInferencer(BaseInferencer):
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
'save_vis'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
'print_result', 'return_datasample', 'save_pred'
}
loading_transforms: list = ['LoadImageFromFile', 'LoadImageFromNDArray']
@ -55,26 +56,69 @@ class BaseMMOCRInferencer(BaseInferencer):
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmocr') -> None:
# A global counter tracking the number of images processed, for
# naming of the output images
self.num_visualized_imgs = 0
scope: str = 'mmocr') -> None:
# A global counter tracking the number of images given in the form
# of ndarray, for naming the output images
self.num_unnamed_imgs = 0
init_default_scope(scope)
super().__init__(
model=model, weights=weights, device=device, scope=scope)
def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
"""Process the inputs into a model-feedable format.
Args:
inputs (InputsType): Inputs given by user.
batch_size (int): batch size. Defaults to 1.
Yields:
Any: Data processed by the ``pipeline`` and ``collate_fn``.
"""
chunked_data = self._get_chunk_data(inputs, batch_size)
yield from map(self.collate_fn, chunked_data)
def _get_chunk_data(self, inputs: Iterable, chunk_size: int):
"""Get batch data from inputs.
Args:
inputs (Iterable): An iterable dataset.
chunk_size (int): Equivalent to batch size.
Yields:
list: batch data.
"""
inputs_iter = iter(inputs)
while True:
try:
chunk_data = []
for _ in range(chunk_size):
inputs_ = next(inputs_iter)
pipe_out = self.pipeline(inputs_)
if pipe_out['data_samples'].get('img_path') is None:
pipe_out['data_samples'].set_metainfo(
dict(img_path=f'{self.num_unnamed_imgs}.jpg'))
self.num_unnamed_imgs += 1
chunk_data.append((inputs_, pipe_out))
yield chunk_data
except StopIteration:
if chunk_data:
yield chunk_data
break
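The generator above is the heart of the batching support: it pushes each input through the test pipeline, assigns a synthetic img_path to unnamed ndarray inputs, and yields fixed-size chunks. The following standalone sketch (illustrative only, with the pipeline call stripped out) shows the chunking behaviour, including the final short batch:

def chunk(iterable, chunk_size):
    # Yield lists of up to chunk_size consecutive items from iterable.
    it = iter(iterable)
    while True:
        chunk_data = []
        try:
            for _ in range(chunk_size):
                chunk_data.append(next(it))
            yield chunk_data
        except StopIteration:
            if chunk_data:
                yield chunk_data  # the last batch may be smaller than chunk_size
            break

assert list(chunk(range(5), 2)) == [[0, 1], [2, 3], [4]]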
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
progress_bar: bool = True,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '',
out_dir: str = 'results/',
save_vis: bool = False,
save_pred: bool = False,
print_result: bool = False,
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer.
@ -85,6 +129,8 @@ class BaseMMOCRInferencer(BaseInferencer):
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Inference batch size. Defaults to 1.
progress_bar (bool): Whether to show a progress bar. Defaults to
True.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the visualization results in a
@ -94,8 +140,11 @@ class BaseMMOCRInferencer(BaseInferencer):
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
out_dir (str): Output directory of results. Defaults to 'results/'.
save_vis (bool): Whether to save the visualization results to
"out_dir". Defaults to False.
save_pred (bool): Whether to save the inference results to
"out_dir". Defaults to False.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
@ -109,22 +158,53 @@ class BaseMMOCRInferencer(BaseInferencer):
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
dict: Inference and visualization results, keyed by
"predictions" and "visualization".
"""
return super().__call__(
inputs,
return_datasamples,
batch_size,
if (save_vis or save_pred) and not out_dir:
raise ValueError('out_dir must be specified when save_vis or '
'save_pred is True!')
if out_dir:
img_out_dir = osp.join(out_dir, 'vis')
pred_out_dir = osp.join(out_dir, 'preds')
else:
img_out_dir, pred_out_dir = '', ''
(
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
) = self._dispatch_kwargs(
return_vis=return_vis,
show=show,
wait_time=wait_time,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
img_out_dir=img_out_dir,
save_vis=save_vis,
save_pred=save_pred,
print_result=print_result,
pred_out_file=pred_out_file,
**kwargs)
ori_inputs = self._inputs_to_list(inputs)
inputs = self.preprocess(
ori_inputs, batch_size=batch_size, **preprocess_kwargs)
results = {'predictions': [], 'visualization': []}
for ori_inputs, data in track(
inputs, description='Inference', disable=not progress_bar):
preds = self.forward(data, **forward_kwargs)
visualization = self.visualize(
ori_inputs, preds, img_out_dir=img_out_dir, **visualize_kwargs)
batch_res = self.postprocess(
preds,
visualization,
return_datasamples,
pred_out_dir=pred_out_dir,
**postprocess_kwargs)
results['predictions'].extend(batch_res['predictions'])
if batch_res['visualization'] is not None:
results['visualization'].extend(batch_res['visualization'])
return results
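For illustration, with save_vis=True and save_pred=True the loop above produces a layout like the following under out_dir (the image names are hypothetical; ndarray inputs fall back to the unnamed-image counter, e.g. 0.jpg / 0.json):

results/
    vis/
        img_1.jpg
        img_2.jpg
    preds/
        img_1.json
        img_2.json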
def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline
@ -170,6 +250,7 @@ class BaseMMOCRInferencer(BaseInferencer):
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
save_vis: bool = False,
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions.
@ -185,6 +266,8 @@ class BaseMMOCRInferencer(BaseInferencer):
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
save_vis (bool): Whether to save the visualization result. Defaults
to False.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
@ -192,8 +275,7 @@ class BaseMMOCRInferencer(BaseInferencer):
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
if self.visualizer is None or (not show and img_out_dir == ''
and not return_vis):
if self.visualizer is None or not (show or save_vis or return_vis):
return None
if getattr(self, 'visualizer') is None:
@ -206,17 +288,19 @@ class BaseMMOCRInferencer(BaseInferencer):
if isinstance(single_input, str):
img_bytes = mmengine.fileio.get(single_input)
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
img_name = osp.basename(single_input)
elif isinstance(single_input, np.ndarray):
img = single_input.copy()[:, :, ::-1] # to RGB
img_num = str(self.num_visualized_imgs).zfill(8)
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type: '
f'{type(single_input)}')
img_name = osp.splitext(osp.basename(pred.img_path))[0]
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None
if save_vis and img_out_dir:
out_file = osp.splitext(img_name)[0]
out_file = f'{out_file}.jpg'
out_file = osp.join(img_out_dir, out_file)
else:
out_file = None
visualization = self.visualizer.add_datasample(
img_name,
@ -230,7 +314,6 @@ class BaseMMOCRInferencer(BaseInferencer):
out_file=out_file,
)
results.append(visualization)
self.num_visualized_imgs += 1
return results
@ -240,7 +323,8 @@ class BaseMMOCRInferencer(BaseInferencer):
visualization: Optional[List[np.ndarray]] = None,
return_datasample: bool = False,
print_result: bool = False,
pred_out_file: str = '',
save_pred: bool = False,
pred_out_dir: str = '',
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
@ -258,7 +342,9 @@ class BaseMMOCRInferencer(BaseInferencer):
inference results. If False, dict will be used.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
save_pred (bool): Whether to save the inference result. Defaults to
False.
pred_out_dir (str): Directory to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
@ -279,13 +365,16 @@ class BaseMMOCRInferencer(BaseInferencer):
results = []
for pred in preds:
result = self.pred2dict(pred)
if save_pred and pred_out_dir:
pred_name = osp.splitext(osp.basename(pred.img_path))[0]
pred_name = f'{pred_name}.json'
pred_out_file = osp.join(pred_out_dir, pred_name)
mmengine.dump(result, pred_out_file)
results.append(result)
# Add visualization to the results after printing and dumping
result_dict['predictions'] = results
if print_result:
print(result_dict)
if pred_out_file != '':
mmengine.dump(result_dict, pred_out_file)
result_dict['visualization'] = visualization
return result_dict


@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union
import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.dataset import Compose, pseudo_collate
from mmengine.runner.checkpoint import _load_checkpoint
from mmocr.registry import DATASETS
@ -45,6 +45,7 @@ class KIEInferencer(BaseMMOCRInferencer):
super().__init__(
model=model, weights=weights, device=device, scope=scope)
self._load_metainfo_to_visualizer(weights, self.cfg)
self.collate_fn = self.kie_collate
def _load_metainfo_to_visualizer(self, weights: Optional[str],
cfg: ConfigType) -> None:
@ -90,6 +91,21 @@ class KIEInferencer(BaseMMOCRInferencer):
return Compose(pipeline_cfg)
return super()._init_pipeline(cfg)
@staticmethod
def kie_collate(data_batch: Sequence) -> Any:
"""A collate function designed for KIE, where the first element (input)
is a dict and we only want to keep it as-is instead of batching
elements inside.
Returns:
Any: Transversed Data in the same format as the data_itement of
``data_batch``.
""" # noqa: E501
transposed = list(zip(*data_batch))
for i in range(1, len(transposed)):
transposed[i] = pseudo_collate(transposed[i])
return transposed
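A rough sketch of what kie_collate does to one chunk, assuming each chunk element is an (input_dict, packed_data) pair as produced by _get_chunk_data (the sample values are made up):

from mmengine.dataset import pseudo_collate

data_batch = [({'img': 'a.jpg'}, {'x': 1}), ({'img': 'b.jpg'}, {'x': 2})]
transposed = list(zip(*data_batch))
transposed[1] = pseudo_collate(transposed[1])
# transposed[0] -> ({'img': 'a.jpg'}, {'img': 'b.jpg'})  raw input dicts kept as-is
# transposed[1] -> pseudo-collated packed data, roughly {'x': [1, 2]}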
def _inputs_to_list(self, inputs: InputsType) -> list:
"""Preprocess the inputs to a list.
@ -167,94 +183,9 @@ class KIEInferencer(BaseMMOCRInferencer):
else:
atype = type(single_input['img'])
raise ValueError(f'Unsupported input type: {atype}')
return processed_inputs
def __call__(self,
inputs: InputsType,
return_datasamples: bool = False,
batch_size: int = 1,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '',
print_result: bool = False,
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer.
The inputs for KIE Inferencer is special compared to other tasks.
They can be a dict or list[dict], where each dictionary contains
the following keys:
- img (str or ndarray): Path to the image or the image itself. If KIE
Inferencer is used in no-visual mode, this key is not required.
Note: If it's a numpy array, it should be in BGR order.
- img_shape (tuple(int, int)): Image shape in (H, W). In
- instances (list[dict]): A list of instances.
- bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box.
- text (str): Annotation text.
Each ``instance`` looks like the following:
.. code-block:: python
{
# A nested list of 4 numbers representing the bounding box of
# the instance, in (x1, y1, x2, y2) order.
'bbox': np.array([[x1, y1, x2, y2], [x1, y1, x2, y2], ...],
dtype=np.int32),
# List of texts.
"texts": ['text1', 'text2', ...],
}
Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Inference batch size. Defaults to 1.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the visualization results in a
popup window. Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
"""
return super().__call__(
inputs,
return_datasamples,
batch_size,
return_vis=return_vis,
show=show,
wait_time=wait_time,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
img_out_dir=img_out_dir,
print_result=print_result,
pred_out_file=pred_out_file,
**kwargs)
def visualize(self,
inputs: InputsType,
preds: PredType,
@ -263,7 +194,8 @@ class KIEInferencer(BaseMMOCRInferencer):
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> List[np.ndarray]:
save_vis: bool = False,
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions.
Args:
@ -278,10 +210,16 @@ class KIEInferencer(BaseMMOCRInferencer):
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
save_vis (bool): Whether to save the visualization result. Defaults
to False.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
if self.visualizer is None or (not show and img_out_dir == ''
and not return_vis):
if self.visualizer is None or not (show or save_vis or return_vis):
return None
if getattr(self, 'visualizer') is None:
@ -291,24 +229,26 @@ class KIEInferencer(BaseMMOCRInferencer):
results = []
for single_input, pred in zip(inputs, preds):
img_num = str(self.num_visualized_imgs).zfill(8)
assert 'img' in single_input or 'img_shape' in single_input
if 'img' in single_input:
if isinstance(single_input['img'], str):
img = mmcv.imread(single_input['img'], channel_order='rgb')
img_name = osp.basename(single_input['img'])
img_bytes = mmengine.fileio.get(single_input['img'])
img = mmcv.imfrombytes(img_bytes, channel_order='rgb')
elif isinstance(single_input['img'], np.ndarray):
img = single_input['img'].copy()[:, :, ::-1] # To RGB
img_name = f'{img_num}.jpg'
elif 'img_shape' in single_input:
img = np.zeros(single_input['img_shape'], dtype=np.uint8)
img_name = f'{img_num}.jpg'
else:
raise ValueError('Input does not contain either "img" or '
'"img_shape"')
img_name = osp.splitext(osp.basename(pred.img_path))[0]
out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None
if save_vis and img_out_dir:
out_file = osp.splitext(img_name)[0]
out_file = f'{out_file}.jpg'
out_file = osp.join(img_out_dir, out_file)
else:
out_file = None
visualization = self.visualizer.add_datasample(
img_name,
@ -322,7 +262,6 @@ class KIEInferencer(BaseMMOCRInferencer):
out_file=out_file,
)
results.append(visualization)
self.num_visualized_imgs += 1
return results


@ -1,16 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
import mmcv
import mmengine
import numpy as np
from mmengine.fileio import (get_file_backend, isdir, join_path,
list_dir_or_file)
from rich.progress import track
from mmocr.registry import VISUALIZERS
from mmocr.structures.textdet_data_sample import TextDetDataSample
from mmocr.structures import TextSpottingDataSample
from mmocr.utils import ConfigType, bbox2poly, crop_img, poly2bbox
from .base_mmocr_inferencer import (BaseMMOCRInferencer, InputsType, PredType,
ResType)
@ -65,7 +64,6 @@ class MMOCRInferencer(BaseMMOCRInferencer):
'provided.')
self.visualizer = None
self.num_visualized_imgs = 0
if det is not None:
self.textdet_inferencer = TextDetInferencer(
@ -93,43 +91,19 @@ class MMOCRInferencer(BaseMMOCRInferencer):
self.kie_inferencer = KIEInferencer(kie, kie_weights, device)
self.mode = 'det_rec_kie'
def _inputs_to_list(self, inputs: InputsType) -> list:
"""Preprocess the inputs to a list.
Preprocess inputs to a list according to its type:
- list or tuple: return inputs
- str:
- Directory path: return all files in the directory
- normal string: return a list containing the string
Args:
inputs (InputsType): Inputs for the inferencer.
Returns:
list: List of input for the :meth:`preprocess`.
"""
inputs = copy.deepcopy(inputs)
if isinstance(inputs, str):
backend = get_file_backend(inputs)
if hasattr(backend, 'isdir') and isdir(inputs):
# Backends like HttpsBackend do not implement `isdir`, so only
# those backends that implement `isdir` could accept the inputs
# as a directory
filename_list = list_dir_or_file(inputs, list_dir=False)
inputs = [
join_path(inputs, filename) for filename in filename_list
]
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
for i in range(len(inputs)):
if not isinstance(inputs[i], np.ndarray):
img_bytes = mmengine.fileio.get(inputs[i])
inputs[i] = mmcv.imfrombytes(img_bytes)
return list(inputs)
def _inputs2ndarrray(self, inputs: List[InputsType]) -> List[np.ndarray]:
"""Preprocess the inputs to a list of numpy arrays."""
new_inputs = []
for item in inputs:
if isinstance(item, np.ndarray):
new_inputs.append(item)
elif isinstance(item, str):
img_bytes = mmengine.fileio.get(item)
new_inputs.append(mmcv.imfrombytes(img_bytes))
else:
raise NotImplementedError(f'The input type {type(item)} is not '
'supported yet.')
return new_inputs
def forward(self, inputs: InputsType, batch_size: int,
**forward_kwargs) -> PredType:
@ -144,6 +118,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
"kie"..
"""
result = {}
forward_kwargs['progress_bar'] = False
if self.mode == 'rec':
# The extra list wrapper here is for the ease of postprocessing
self.rec_inputs = inputs
@ -153,15 +128,16 @@ class MMOCRInferencer(BaseMMOCRInferencer):
batch_size=batch_size,
**forward_kwargs)['predictions']
result['rec'] = [[p] for p in predictions]
elif self.mode.startswith('det'):
elif self.mode.startswith('det'): # 'det'/'det_rec'/'det_rec_kie'
result['det'] = self.textdet_inferencer(
inputs,
return_datasamples=True,
batch_size=batch_size,
**forward_kwargs)['predictions']
if self.mode.startswith('det_rec'):
if self.mode.startswith('det_rec'): # 'det_rec'/'det_rec_kie'
result['rec'] = []
for img, det_data_sample in zip(inputs, result['det']):
for img, det_data_sample in zip(
self._inputs2ndarrray(inputs), result['det']):
det_pred = det_data_sample.pred_instances
self.rec_inputs = []
for polygon in det_pred['polygons']:
@ -177,6 +153,10 @@ class MMOCRInferencer(BaseMMOCRInferencer):
**forward_kwargs)['predictions'])
if self.mode == 'det_rec_kie':
self.kie_inputs = []
# TODO: when the det output is empty, kie will fail
# as no gt-instances can be provided. It's a known
# issue but cannot be solved elegantly since we support
# batch inference.
for img, det_data_sample, rec_data_samples in zip(
inputs, result['det'], result['rec']):
det_pred = det_data_sample.pred_instances
@ -197,7 +177,7 @@ class MMOCRInferencer(BaseMMOCRInferencer):
return result
def visualize(self, inputs: InputsType, preds: PredType,
**kwargs) -> List[np.ndarray]:
**kwargs) -> Union[List[np.ndarray], None]:
"""Visualize predictions.
Args:
@ -210,7 +190,14 @@ class MMOCRInferencer(BaseMMOCRInferencer):
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of images. Defaults to ''.
save_vis (bool): Whether to save the visualization result. Defaults
to False.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.
Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
if 'kie' in self.mode:
@ -232,6 +219,9 @@ class MMOCRInferencer(BaseMMOCRInferencer):
self,
inputs: InputsType,
batch_size: int = 1,
out_dir: str = 'results/',
save_vis: bool = False,
save_pred: bool = False,
**kwargs,
) -> dict:
"""Call the inferencer.
@ -239,9 +229,12 @@ class MMOCRInferencer(BaseMMOCRInferencer):
Args:
inputs (InputsType): Inputs for the inferencer. It can be a path
to image / image directory, or an array, or a list of these.
return_datasamples (bool): Whether to return results as
:obj:`BaseDataElement`. Defaults to False.
batch_size (int): Batch size. Defaults to 1.
out_dir (str): Output directory of results. Defaults to 'results/'.
save_vis (bool): Whether to save the visualization results to
"out_dir". Defaults to False.
save_pred (bool): Whether to save the inference results to
"out_dir". Defaults to False.
**kwargs: Keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
@ -249,47 +242,75 @@ class MMOCRInferencer(BaseMMOCRInferencer):
and ``postprocess_kwargs``.
Returns:
dict: Inference and visualization results.
dict: Inference and visualization results, keyed by
"predictions" and "visualization".
"""
if (save_vis or save_pred) and not out_dir:
raise ValueError('out_dir must be specified when save_vis or '
'save_pred is True!')
if out_dir:
img_out_dir = osp.join(out_dir, 'vis')
pred_out_dir = osp.join(out_dir, 'preds')
else:
img_out_dir, pred_out_dir = '', ''
(
preprocess_kwargs,
forward_kwargs,
visualize_kwargs,
postprocess_kwargs,
) = self._dispatch_kwargs(**kwargs)
) = self._dispatch_kwargs(
save_vis=save_vis, save_pred=save_pred, **kwargs)
ori_inputs = self._inputs_to_list(inputs)
preds = self.forward(ori_inputs, batch_size, **forward_kwargs)
visualization = self.visualize(
ori_inputs, preds,
**visualize_kwargs) # type: ignore # noqa: E501
results = self.postprocess(preds, visualization, **postprocess_kwargs)
chunked_inputs = super(BaseMMOCRInferencer,
self)._get_chunk_data(ori_inputs, batch_size)
results = {'predictions': [], 'visualization': []}
for ori_input in track(chunked_inputs, description='Inference'):
preds = self.forward(ori_input, batch_size, **forward_kwargs)
visualization = self.visualize(
ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
batch_res = self.postprocess(
preds,
visualization,
pred_out_dir=pred_out_dir,
**postprocess_kwargs)
results['predictions'].extend(batch_res['predictions'])
if batch_res['visualization'] is not None:
results['visualization'].extend(batch_res['visualization'])
return results
def postprocess(self,
preds: PredType,
visualization: Optional[List[np.ndarray]] = None,
print_result: bool = False,
pred_out_file: str = ''
save_pred: bool = False,
pred_out_dir: str = ''
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Postprocess predictions.
"""Process the predictions and visualization results from ``forward``
and ``visualize``.
This method should be responsible for the following tasks:
1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.
Args:
preds (Dict): Predictions of the model.
preds (PredType): Predictions of the model.
visualization (Optional[np.ndarray]): Visualized predictions.
print_result (bool): Whether to print the result.
Defaults to False.
pred_out_file (str): Output file name to store predictions
without images. Supported file formats are json, yaml/yml
and pickle/pkl. Defaults to ''.
save_pred (bool): Whether to save the inference result. Defaults to
False.
pred_out_dir (str): Directory to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
Returns:
Dict or List[Dict]: Each dict contains the inference result of
each image. Possible keys are "det_polygons", "det_scores",
"rec_texts", "rec_scores", "kie_labels", "kie_scores",
"kie_edge_labels" and "kie_edge_scores".
Dict: Inference and visualization results, keyed by
"predictions" and "visualization".
"""
result_dict = {}
@ -320,22 +341,28 @@ class MMOCRInferencer(BaseMMOCRInferencer):
kie_edge_scores=kie_dict_res['edge_scores'],
kie_edge_labels=kie_dict_res['edge_labels'])
if save_pred and pred_out_dir:
pred_key = 'det' if 'det' in self.mode else 'rec'
for pred, pred_result in zip(preds[pred_key], pred_results):
img_path = (
pred.img_path if pred_key == 'det' else pred[0].img_path)
pred_name = osp.splitext(osp.basename(img_path))[0]
pred_name = f'{pred_name}.json'
pred_out_file = osp.join(pred_out_dir, pred_name)
mmengine.dump(pred_result, pred_out_file)
result_dict['predictions'] = pred_results
if print_result:
print(result_dict)
if pred_out_file != '':
mmengine.dump(result_dict, pred_out_file)
result_dict['visualization'] = visualization
return result_dict
def _pack_e2e_datasamples(self, preds: Dict) -> List[TextDetDataSample]:
def _pack_e2e_datasamples(self,
preds: Dict) -> List[TextSpottingDataSample]:
"""Pack text detection and recognition results into a list of
TextDetDataSample.
Note that it is a temporary solution since the TextSpottingDataSample
is not ready.
"""
TextSpottingDataSample."""
results = []
for det_data_sample, rec_data_samples in zip(preds['det'],
preds['rec']):
texts = []


@ -9,10 +9,10 @@ def parse_args():
parser.add_argument(
'inputs', type=str, help='Input image file or folder path.')
parser.add_argument(
'--img-out-dir',
'--out-dir',
type=str,
default='',
help='Output directory of images.')
default='results/',
help='Output directory of results.')
parser.add_argument(
'--det',
type=str,
@ -69,10 +69,13 @@ def parse_args():
action='store_true',
help='Whether to print the results.')
parser.add_argument(
'--pred-out-file',
type=str,
default='',
help='File to save the inference results.')
'--save_pred',
action='store_true',
help='Save the inference results to out_dir.')
parser.add_argument(
'--save_vis',
action='store_true',
help='Save the visualization results to out_dir.')
call_args = vars(parser.parse_args())


@ -85,6 +85,8 @@ class BaseLocalVisualizer(Visualizer):
font_families (Union[str, List[str]]): The font families of labels.
Defaults to 'sans-serif'.
"""
if not labels and not bboxes:
return image
if colors is not None and isinstance(colors, (list, tuple)):
size = math.ceil(len(labels) / len(colors))
colors = (colors * size)[:len(labels)]


@ -44,18 +44,19 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
img_shape = image.shape[:2]
empty_shape = (img_shape[0], img_shape[1], 3)
text_image = np.full(empty_shape, 255, dtype=np.uint8)
text_image = self.get_labels_image(
text_image,
labels=texts,
bboxes=bboxes,
font_families=self.font_families)
if texts:
text_image = self.get_labels_image(
text_image,
labels=texts,
bboxes=bboxes,
font_families=self.font_families)
if polygons:
polygons = [polygon.reshape(-1, 2) for polygon in polygons]
image = self.get_polygons_image(
image, polygons, filling=True, colors=self.PALETTE)
text_image = self.get_polygons_image(
text_image, polygons, colors=self.PALETTE)
else:
elif len(bboxes) > 0:
image = self.get_bboxes_image(
image, bboxes, filling=True, colors=self.PALETTE)
text_image = self.get_bboxes_image(
@ -103,27 +104,28 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer):
"""
cat_images = []
if draw_gt:
gt_bboxes = data_sample.gt_instances.get('bboxes', None)
gt_texts = data_sample.gt_instances.texts
gt_polygons = data_sample.gt_instances.get('polygons', None)
gt_img_data = self._draw_instances(image, gt_bboxes, gt_polygons,
gt_texts)
cat_images.append(gt_img_data)
if data_sample is not None:
if draw_gt and 'gt_instances' in data_sample:
gt_bboxes = data_sample.gt_instances.get('bboxes', None)
gt_texts = data_sample.gt_instances.texts
gt_polygons = data_sample.gt_instances.get('polygons', None)
gt_img_data = self._draw_instances(image, gt_bboxes,
gt_polygons, gt_texts)
cat_images.append(gt_img_data)
if draw_pred:
pred_instances = data_sample.pred_instances
pred_instances = pred_instances[
pred_instances.scores > pred_score_thr].cpu().numpy()
pred_bboxes = pred_instances.get('bboxes', None)
pred_texts = pred_instances.texts
pred_polygons = pred_instances.get('polygons', None)
if pred_bboxes is None:
pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
pred_bboxes = np.array(pred_bboxes)
pred_img_data = self._draw_instances(image, pred_bboxes,
pred_polygons, pred_texts)
cat_images.append(pred_img_data)
if draw_pred and 'pred_instances' in data_sample:
pred_instances = data_sample.pred_instances
pred_instances = pred_instances[
pred_instances.scores > pred_score_thr].cpu().numpy()
pred_bboxes = pred_instances.get('bboxes', None)
pred_texts = pred_instances.texts
pred_polygons = pred_instances.get('polygons', None)
if pred_bboxes is None:
pred_bboxes = [poly2bbox(poly) for poly in pred_polygons]
pred_bboxes = np.array(pred_bboxes)
pred_img_data = self._draw_instances(image, pred_bboxes,
pred_polygons, pred_texts)
cat_images.append(pred_img_data)
cat_images = self._cat_image(cat_images, axis=0)
if cat_images is None:


@ -75,12 +75,15 @@ class TestKIEInferencer(TestCase):
def assert_predictions_equal(self, preds1, preds2):
for pred1, pred2 in zip(preds1, preds2):
self.assertTrue(np.allclose(pred1['labels'], pred2['labels'], 0.1))
self.assertTrue(
np.allclose(pred1['edge_scores'], pred2['edge_scores'], 0.1))
self.assertTrue(
np.allclose(pred1['edge_labels'], pred2['edge_labels'], 0.1))
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
self.assert_prediction_equal(pred1, pred2)
def assert_prediction_equal(self, pred1, pred2):
self.assertTrue(np.allclose(pred1['labels'], pred2['labels'], 0.1))
self.assertTrue(
np.allclose(pred1['edge_scores'], pred2['edge_scores'], 0.1))
self.assertTrue(
np.allclose(pred1['edge_labels'], pred2['edge_labels'], 0.1))
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
def test_call(self):
# no visual, single input
@ -130,9 +133,9 @@ class TestKIEInferencer(TestCase):
# img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir:
self.inferencer(self.data_img_str, img_out_dir=tmp_dir)
for img_dir in ['00000000.jpg', '00000001.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
self.inferencer(self.data_img_str, out_dir=tmp_dir, save_vis=True)
for img_dir in ['1.jpg', '2.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
def test_postprocess(self):
# return_datasample
@ -141,14 +144,19 @@ class TestKIEInferencer(TestCase):
# pred_out_file
with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
res = self.inferencer(
self.data_img_ndarray,
print_result=True,
pred_out_file=pred_out_file)
dumped_res = mmengine.load(pred_out_file)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])
out_dir=tmp_dir,
save_pred=True)
file_names = [
f'{self.inferencer.num_unnamed_imgs - i}.json'
for i in range(len(self.data_img_ndarray), 0, -1)
]
for pred, file_name in zip(res['predictions'], file_names):
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', file_name))
self.assert_prediction_equal(dumped_res, pred)
@mock.patch('mmocr.apis.inferencers.kie_inferencer._load_checkpoint')
def test_load_metainfo_to_visualizer(self, mock_load):


@ -95,6 +95,18 @@ class TestMMOCRInferencer(TestCase):
np.allclose(res_img_ndarray['visualization'][0],
res_img_path['visualization'][0]))
# test save_vis and save_pred
with tempfile.TemporaryDirectory() as tmp_dir:
res = inferencer(
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
for img_dir in ['img_1.jpg', 'img_2.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
for i, pred_dir in enumerate(['img_1.json', 'img_2.json']):
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', pred_dir))
self.assert_predictions_equal(res['predictions'][i],
dumped_res)
@mock.patch('mmengine.infer.infer._load_checkpoint')
def test_rec(self, mock_load):
mock_load.side_effect = lambda *x, **y: None
@ -129,6 +141,18 @@ class TestMMOCRInferencer(TestCase):
np.allclose(res_img_ndarray['visualization'][0],
res_img_path['visualization'][0]))
# test save_vis and save_pred
with tempfile.TemporaryDirectory() as tmp_dir:
res = inferencer(
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
for img_dir in ['1036169.jpg', '1058891.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
for i, pred_dir in enumerate(['1036169.json', '1058891.json']):
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', pred_dir))
self.assert_predictions_equal(res['predictions'][i],
dumped_res)
@mock.patch('mmengine.infer.infer._load_checkpoint')
def test_det_rec(self, mock_load):
mock_load.side_effect = lambda *x, **y: None
@ -166,6 +190,21 @@ class TestMMOCRInferencer(TestCase):
np.allclose(res_img_ndarray['visualization'][0],
res_img_path['visualization'][0]))
# test save_vis and save_pred
with tempfile.TemporaryDirectory() as tmp_dir:
res = inferencer(
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
for img_dir in ['img_1.jpg', 'img_2.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
for i, pred_dir in enumerate(['img_1.json', 'img_2.json']):
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', pred_dir))
self.assert_predictions_equal(res['predictions'][i],
dumped_res)
# corner case: when the det model cannot detect any texts
inferencer(np.zeros((100, 100, 3)), return_vis=True)
@mock.patch('mmengine.infer.infer._load_checkpoint')
def test_dec_rec_kie(self, mock_load):
mock_load.side_effect = lambda *x, **y: None
@ -205,18 +244,14 @@ class TestMMOCRInferencer(TestCase):
np.allclose(res_img_ndarray['visualization'][0],
res_img_path['visualization'][0]))
# test visualization
# img_out_dir
# test save_vis and save_pred
with tempfile.TemporaryDirectory() as tmp_dir:
inferencer(img_paths, img_out_dir=tmp_dir)
for img_dir in ['00000006.jpg', '00000007.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
# pred_out_file
with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
res = inferencer(
img_path, print_result=True, pred_out_file=pred_out_file)
dumped_res = mmengine.load(pred_out_file)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])
img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True)
for img_dir in ['1.jpg', '2.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
for i, pred_dir in enumerate(['1.json', '2.json']):
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', pred_dir))
self.assert_predictions_equal(res['predictions'][i],
dumped_res)


@ -41,9 +41,11 @@ class TestTextDetinferencer(TestCase):
def assert_predictions_equal(self, preds1, preds2):
for pred1, pred2 in zip(preds1, preds2):
self.assertTrue(
np.allclose(pred1['polygons'], pred2['polygons'], 0.1))
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
self.assert_prediction_equal(pred1, pred2)
def assert_prediction_equal(self, pred1, pred2):
self.assertTrue(np.allclose(pred1['polygons'], pred2['polygons'], 0.1))
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
def test_call(self):
# single img
@ -91,9 +93,9 @@ class TestTextDetinferencer(TestCase):
# img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir:
self.inferencer(img_paths, img_out_dir=tmp_dir)
self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True)
for img_dir in ['img_1.jpg', 'img_2.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
def test_postprocess(self):
# return_datasample
@ -101,14 +103,13 @@ class TestTextDetinferencer(TestCase):
res = self.inferencer(img_path, return_datasamples=True)
self.assertTrue(is_type_list(res['predictions'], TextDetDataSample))
# pred_out_file
# dump predictions
with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
res = self.inferencer(
img_path, print_result=True, pred_out_file=pred_out_file)
dumped_res = mmengine.load(pred_out_file)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])
img_path, print_result=True, out_dir=tmp_dir, save_pred=True)
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', 'img_1.json'))
self.assert_prediction_equal(res['predictions'][0], dumped_res)
def test_pred2dict(self):
data_sample = TextDetDataSample()


@ -40,8 +40,11 @@ class TestTextRecinferencer(TestCase):
def assert_predictions_equal(self, preds1, preds2):
for pred1, pred2 in zip(preds1, preds2):
self.assertEqual(pred1['text'], pred2['text'])
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
self.assert_prediction_equal(pred1, pred2)
def assert_prediction_equal(self, pred1, pred2):
self.assertEqual(pred1['text'], pred2['text'])
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
def test_call(self):
# single img
@ -86,9 +89,9 @@ class TestTextRecinferencer(TestCase):
# img_out_dir
with tempfile.TemporaryDirectory() as tmp_dir:
self.inferencer(img_paths, img_out_dir=tmp_dir)
self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True)
for img_dir in ['1036169.jpg', '1058891.jpg']:
self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir)))
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
def test_postprocess(self):
# return_datasample
@ -98,9 +101,8 @@ class TestTextRecinferencer(TestCase):
# pred_out_file
with tempfile.TemporaryDirectory() as tmp_dir:
pred_out_file = osp.join(tmp_dir, 'tmp.pkl')
res = self.inferencer(
img_path, print_result=True, pred_out_file=pred_out_file)
dumped_res = mmengine.load(pred_out_file)
self.assert_predictions_equal(res['predictions'],
dumped_res['predictions'])
img_path, print_result=True, out_dir=tmp_dir, save_pred=True)
dumped_res = mmengine.load(
osp.join(tmp_dir, 'preds', '1036169.json'))
self.assert_prediction_equal(res['predictions'][0], dumped_res)