diff --git a/mmocr/apis/inferencers/base_mmocr_inferencer.py b/mmocr/apis/inferencers/base_mmocr_inferencer.py index a3f96869..f27fc870 100644 --- a/mmocr/apis/inferencers/base_mmocr_inferencer.py +++ b/mmocr/apis/inferencers/base_mmocr_inferencer.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union import mmcv import mmengine @@ -9,6 +9,7 @@ from mmengine.dataset import Compose from mmengine.infer.infer import BaseInferencer, ModelType from mmengine.registry import init_default_scope from mmengine.structures import InstanceData +from rich.progress import track from torch import Tensor from mmocr.utils import ConfigType @@ -44,10 +45,10 @@ class BaseMMOCRInferencer(BaseInferencer): forward_kwargs: set = set() visualize_kwargs: set = { 'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr', - 'img_out_dir' + 'save_vis' } postprocess_kwargs: set = { - 'print_result', 'pred_out_file', 'return_datasample' + 'print_result', 'return_datasample', 'save_pred' } loading_transforms: list = ['LoadImageFromFile', 'LoadImageFromNDArray'] @@ -55,26 +56,69 @@ class BaseMMOCRInferencer(BaseInferencer): model: Union[ModelType, str, None] = None, weights: Optional[str] = None, device: Optional[str] = None, - scope: Optional[str] = 'mmocr') -> None: - # A global counter tracking the number of images processed, for - # naming of the output images - self.num_visualized_imgs = 0 + scope: str = 'mmocr') -> None: + # A global counter tracking the number of images given in the form + # of ndarray, for naming the output images + self.num_unnamed_imgs = 0 init_default_scope(scope) super().__init__( model=model, weights=weights, device=device, scope=scope) + def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs): + """Process the inputs into a model-feedable format. + + Args: + inputs (InputsType): Inputs given by user. + batch_size (int): batch size. Defaults to 1. + + Yields: + Any: Data processed by the ``pipeline`` and ``collate_fn``. + """ + chunked_data = self._get_chunk_data(inputs, batch_size) + yield from map(self.collate_fn, chunked_data) + + def _get_chunk_data(self, inputs: Iterable, chunk_size: int): + """Get batch data from inputs. + + Args: + inputs (Iterable): An iterable dataset. + chunk_size (int): Equivalent to batch size. + + Yields: + list: batch data. + """ + inputs_iter = iter(inputs) + while True: + try: + chunk_data = [] + for _ in range(chunk_size): + inputs_ = next(inputs_iter) + pipe_out = self.pipeline(inputs_) + if pipe_out['data_samples'].get('img_path') is None: + pipe_out['data_samples'].set_metainfo( + dict(img_path=f'{self.num_unnamed_imgs}.jpg')) + self.num_unnamed_imgs += 1 + chunk_data.append((inputs_, pipe_out)) + yield chunk_data + except StopIteration: + if chunk_data: + yield chunk_data + break + def __call__(self, inputs: InputsType, return_datasamples: bool = False, batch_size: int = 1, + progress_bar: bool = True, return_vis: bool = False, show: bool = False, wait_time: int = 0, draw_pred: bool = True, pred_score_thr: float = 0.3, - img_out_dir: str = '', + out_dir: str = 'results/', + save_vis: bool = False, + save_pred: bool = False, print_result: bool = False, - pred_out_file: str = '', **kwargs) -> dict: """Call the inferencer. @@ -85,6 +129,8 @@ class BaseMMOCRInferencer(BaseInferencer): return_datasamples (bool): Whether to return results as :obj:`BaseDataElement`. 
Defaults to False. batch_size (int): Inference batch size. Defaults to 1. + progress_bar (bool): Whether to show a progress bar. Defaults to + True. return_vis (bool): Whether to return the visualization result. Defaults to False. show (bool): Whether to display the visualization results in a @@ -94,8 +140,11 @@ class BaseMMOCRInferencer(BaseInferencer): Defaults to True. pred_score_thr (float): Minimum score of bboxes to draw. Defaults to 0.3. - img_out_dir (str): Output directory of visualization results. - If left as empty, no file will be saved. Defaults to ''. + out_dir (str): Output directory of results. Defaults to 'results/'. + save_vis (bool): Whether to save the visualization results to + "out_dir". Defaults to False. + save_pred (bool): Whether to save the inference results to + "out_dir". Defaults to False. print_result (bool): Whether to print the inference result w/o visualization to the console. Defaults to False. pred_out_file: File to save the inference results w/o @@ -109,22 +158,53 @@ class BaseMMOCRInferencer(BaseInferencer): and ``postprocess_kwargs``. Returns: - dict: Inference and visualization results. + dict: Inference and visualization results, mapped from + "predictions" and "visualization". """ - return super().__call__( - inputs, - return_datasamples, - batch_size, + if (save_vis or save_pred) and not out_dir: + raise ValueError('out_dir must be specified when save_vis or ' + 'save_pred is True!') + if out_dir: + img_out_dir = osp.join(out_dir, 'vis') + pred_out_dir = osp.join(out_dir, 'preds') + else: + img_out_dir, pred_out_dir = '', '' + ( + preprocess_kwargs, + forward_kwargs, + visualize_kwargs, + postprocess_kwargs, + ) = self._dispatch_kwargs( return_vis=return_vis, show=show, wait_time=wait_time, draw_pred=draw_pred, pred_score_thr=pred_score_thr, - img_out_dir=img_out_dir, + save_vis=save_vis, + save_pred=save_pred, print_result=print_result, - pred_out_file=pred_out_file, **kwargs) + ori_inputs = self._inputs_to_list(inputs) + inputs = self.preprocess( + ori_inputs, batch_size=batch_size, **preprocess_kwargs) + results = {'predictions': [], 'visualization': []} + for ori_inputs, data in track( + inputs, description='Inference', disable=not progress_bar): + preds = self.forward(data, **forward_kwargs) + visualization = self.visualize( + ori_inputs, preds, img_out_dir=img_out_dir, **visualize_kwargs) + batch_res = self.postprocess( + preds, + visualization, + return_datasamples, + pred_out_dir=pred_out_dir, + **postprocess_kwargs) + results['predictions'].extend(batch_res['predictions']) + if batch_res['visualization'] is not None: + results['visualization'].extend(batch_res['visualization']) + return results + def _init_pipeline(self, cfg: ConfigType) -> Compose: """Initialize the test pipeline.""" pipeline_cfg = cfg.test_dataloader.dataset.pipeline @@ -170,6 +250,7 @@ class BaseMMOCRInferencer(BaseInferencer): wait_time: int = 0, draw_pred: bool = True, pred_score_thr: float = 0.3, + save_vis: bool = False, img_out_dir: str = '') -> Union[List[np.ndarray], None]: """Visualize predictions. @@ -185,6 +266,8 @@ class BaseMMOCRInferencer(BaseInferencer): Defaults to True. pred_score_thr (float): Minimum score of bboxes to draw. Defaults to 0.3. + save_vis (bool): Whether to save the visualization result. Defaults + to False. img_out_dir (str): Output directory of visualization results. If left as empty, no file will be saved. Defaults to ''. 
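The refactored ``__call__`` above replaces the old ``img_out_dir``/``pred_out_file`` pair with a single ``out_dir`` plus two boolean switches, and streams chunked batches through ``preprocess`` under a ``rich`` progress bar. A minimal usage sketch of the new contract (the metafile name and demo path are placeholders; any model the inferencer can resolve works the same way):

.. code-block:: python

    from mmocr.apis.inferencers import TextDetInferencer

    inferencer = TextDetInferencer(model='dbnet_resnet18_fpnc_1200e_icdar2015')
    # Visualizations are written to results/vis/ and per-image JSON
    # predictions to results/preds/; nothing is saved unless the
    # corresponding save_* flag is set.
    results = inferencer(
        'demo/demo_text_det.jpg',
        out_dir='results/',
        save_vis=True,
        save_pred=True,
        progress_bar=False)
    print(results.keys())  # dict_keys(['predictions', 'visualization'])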
@@ -192,8 +275,7 @@ class BaseMMOCRInferencer(BaseInferencer): List[np.ndarray] or None: Returns visualization results only if applicable. """ - if self.visualizer is None or (not show and img_out_dir == '' - and not return_vis): + if self.visualizer is None or not (show or save_vis or return_vis): return None if getattr(self, 'visualizer') is None: @@ -206,17 +288,19 @@ class BaseMMOCRInferencer(BaseInferencer): if isinstance(single_input, str): img_bytes = mmengine.fileio.get(single_input) img = mmcv.imfrombytes(img_bytes, channel_order='rgb') - img_name = osp.basename(single_input) elif isinstance(single_input, np.ndarray): img = single_input.copy()[:, :, ::-1] # to RGB - img_num = str(self.num_visualized_imgs).zfill(8) - img_name = f'{img_num}.jpg' else: raise ValueError('Unsupported input type: ' f'{type(single_input)}') + img_name = osp.splitext(osp.basename(pred.img_path))[0] - out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \ - else None + if save_vis and img_out_dir: + out_file = osp.splitext(img_name)[0] + out_file = f'{out_file}.jpg' + out_file = osp.join(img_out_dir, out_file) + else: + out_file = None visualization = self.visualizer.add_datasample( img_name, @@ -230,7 +314,6 @@ class BaseMMOCRInferencer(BaseInferencer): out_file=out_file, ) results.append(visualization) - self.num_visualized_imgs += 1 return results @@ -240,7 +323,8 @@ class BaseMMOCRInferencer(BaseInferencer): visualization: Optional[List[np.ndarray]] = None, return_datasample: bool = False, print_result: bool = False, - pred_out_file: str = '', + save_pred: bool = False, + pred_out_dir: str = '', ) -> Union[ResType, Tuple[ResType, np.ndarray]]: """Process the predictions and visualization results from ``forward`` and ``visualize``. @@ -258,7 +342,9 @@ class BaseMMOCRInferencer(BaseInferencer): inference results. If False, dict will be used. print_result (bool): Whether to print the inference result w/o visualization to the console. Defaults to False. - pred_out_file: File to save the inference results w/o + save_pred (bool): Whether to save the inference result. Defaults to + False. + pred_out_dir (str): Directory to save the inference results w/o + visualization. If left as empty, no file will be saved. Defaults to ''. @@ -279,13 +365,16 @@ class BaseMMOCRInferencer(BaseInferencer): results = [] for pred in preds: result = self.pred2dict(pred) + if save_pred and pred_out_dir: + pred_name = osp.splitext(osp.basename(pred.img_path))[0] + pred_name = f'{pred_name}.json' + pred_out_file = osp.join(pred_out_dir, pred_name) + mmengine.dump(result, pred_out_file) results.append(result) # Add img to the results after printing and dumping result_dict['predictions'] = results if print_result: print(result_dict) - if pred_out_file != '': - mmengine.dump(result_dict, pred_out_file) result_dict['visualization'] = visualization return result_dict diff --git a/mmocr/apis/inferencers/kie_inferencer.py b/mmocr/apis/inferencers/kie_inferencer.py index 7944798e..c7865d5c 100644 --- a/mmocr/apis/inferencers/kie_inferencer.py +++ b/mmocr/apis/inferencers/kie_inferencer.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved.
import copy import os.path as osp -from typing import Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import mmcv import mmengine import numpy as np -from mmengine.dataset import Compose +from mmengine.dataset import Compose, pseudo_collate from mmengine.runner.checkpoint import _load_checkpoint from mmocr.registry import DATASETS @@ -45,6 +45,7 @@ class KIEInferencer(BaseMMOCRInferencer): super().__init__( model=model, weights=weights, device=device, scope=scope) self._load_metainfo_to_visualizer(weights, self.cfg) + self.collate_fn = self.kie_collate def _load_metainfo_to_visualizer(self, weights: Optional[str], cfg: ConfigType) -> None: @@ -90,6 +91,21 @@ class KIEInferencer(BaseMMOCRInferencer): return Compose(pipeline_cfg) return super()._init_pipeline(cfg) + @staticmethod + def kie_collate(data_batch: Sequence) -> Any: + """A collate function designed for KIE, where the first element (input) + is a dict and we want to keep it as-is instead of batching + the elements inside it. + + Returns: + Any: Transposed data in the same format as the elements of + ``data_batch``. + """ # noqa: E501 + transposed = list(zip(*data_batch)) + for i in range(1, len(transposed)): + transposed[i] = pseudo_collate(transposed[i]) + return transposed + def _inputs_to_list(self, inputs: InputsType) -> list: """Preprocess the inputs to a list. @@ -167,94 +183,9 @@ class KIEInferencer(BaseMMOCRInferencer): else: atype = type(single_input['img']) raise ValueError(f'Unsupported input type: {atype}') + return processed_inputs - def __call__(self, - inputs: InputsType, - return_datasamples: bool = False, - batch_size: int = 1, - return_vis: bool = False, - show: bool = False, - wait_time: int = 0, - draw_pred: bool = True, - pred_score_thr: float = 0.3, - img_out_dir: str = '', - print_result: bool = False, - pred_out_file: str = '', - **kwargs) -> dict: - """Call the inferencer. - - The inputs for KIE Inferencer is special compared to other tasks. - They can be a dict or list[dict], where each dictionary contains - following keys: - - - img (str or ndarray): Path to the image or the image itself. If KIE - Inferencer is used in no-visual mode, this key is not required. - Note: If it's an numpy array, it should be in BGR order. - - img_shape (tuple(int, int)): Image shape in (H, W). In - - instances (list[dict]): A list of instances. - - bbox (ndarray(dtype=np.float32)): Shape (4, ). Bounding box. - - text (str): Annotation text. - - Each ``instance`` looks like the following: - - .. code-block:: python - - { - # A nested list of 4 numbers representing the bounding box of - # the instance, in (x1, y1, x2, y2) order. - 'bbox': np.array([[x1, y1, x2, y2], [x1, y1, x2, y2], ...], - dtype=np.int32), - - # List of texts. - "texts": ['text1', 'text2', ...], - } - - Args: - inputs (InputsType): Inputs for the inferencer. - return_datasamples (bool): Whether to return results as - :obj:`BaseDataElement`. Defaults to False. - batch_size (int): Inference batch size. Defaults to 1. - return_vis (bool): Whether to return the visualization result. - Defaults to False. - show (bool): Whether to display the visualization results in a - popup window. Defaults to False. - wait_time (float): The interval of show (s). Defaults to 0. - draw_pred (bool): Whether to draw predicted bounding boxes. - Defaults to True. - pred_score_thr (float): Minimum score of bboxes to draw. - Defaults to 0.3. - img_out_dir (str): Output directory of visualization results.
- If left as empty, no file will be saved. Defaults to ''. - print_result (bool): Whether to print the inference result w/o - visualization to the console. Defaults to False. - pred_out_file: File to save the inference results w/o - visualization. If left as empty, no file will be saved. - Defaults to ''. - - **kwargs: Other keyword arguments passed to :meth:`preprocess`, - :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. - Each key in kwargs should be in the corresponding set of - ``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` - and ``postprocess_kwargs``. - - Returns: - dict: Inference and visualization results. - """ - return super().__call__( - inputs, - return_datasamples, - batch_size, - return_vis=return_vis, - show=show, - wait_time=wait_time, - draw_pred=draw_pred, - pred_score_thr=pred_score_thr, - img_out_dir=img_out_dir, - print_result=print_result, - pred_out_file=pred_out_file, - **kwargs) - def visualize(self, inputs: InputsType, preds: PredType, @@ -263,7 +194,8 @@ class KIEInferencer(BaseMMOCRInferencer): wait_time: int = 0, draw_pred: bool = True, pred_score_thr: float = 0.3, - img_out_dir: str = '') -> List[np.ndarray]: + save_vis: bool = False, + img_out_dir: str = '') -> Union[List[np.ndarray], None]: """Visualize predictions. Args: @@ -278,10 +210,16 @@ class KIEInferencer(BaseMMOCRInferencer): Defaults to True. pred_score_thr (float): Minimum score of bboxes to draw. Defaults to 0.3. - img_out_dir (str): Output directory of images. Defaults to ''. + save_vis (bool): Whether to save the visualization result. Defaults + to False. + img_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + + Returns: + List[np.ndarray] or None: Returns visualization results only if + applicable. 
""" - if self.visualizer is None or (not show and img_out_dir == '' - and not return_vis): + if self.visualizer is None or not (show or save_vis or return_vis): return None if getattr(self, 'visualizer') is None: @@ -291,24 +229,26 @@ class KIEInferencer(BaseMMOCRInferencer): results = [] for single_input, pred in zip(inputs, preds): - img_num = str(self.num_visualized_imgs).zfill(8) assert 'img' in single_input or 'img_shape' in single_input if 'img' in single_input: if isinstance(single_input['img'], str): - img = mmcv.imread(single_input['img'], channel_order='rgb') - img_name = osp.basename(single_input['img']) + img_bytes = mmengine.fileio.get(single_input['img']) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') elif isinstance(single_input['img'], np.ndarray): img = single_input['img'].copy()[:, :, ::-1] # To RGB - img_name = f'{img_num}.jpg' elif 'img_shape' in single_input: img = np.zeros(single_input['img_shape'], dtype=np.uint8) - img_name = f'{img_num}.jpg' else: raise ValueError('Input does not contain either "img" or ' '"img_shape"') + img_name = osp.splitext(osp.basename(pred.img_path))[0] - out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \ - else None + if save_vis and img_out_dir: + out_file = osp.splitext(img_name)[0] + out_file = f'{out_file}.jpg' + out_file = osp.join(img_out_dir, out_file) + else: + out_file = None visualization = self.visualizer.add_datasample( img_name, @@ -322,7 +262,6 @@ class KIEInferencer(BaseMMOCRInferencer): out_file=out_file, ) results.append(visualization) - self.num_visualized_imgs += 1 return results diff --git a/mmocr/apis/inferencers/mmocr_inferencer.py b/mmocr/apis/inferencers/mmocr_inferencer.py index b0ef8aa7..2d88b5d4 100644 --- a/mmocr/apis/inferencers/mmocr_inferencer.py +++ b/mmocr/apis/inferencers/mmocr_inferencer.py @@ -1,16 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy +import os.path as osp from datetime import datetime from typing import Dict, List, Optional, Tuple, Union import mmcv import mmengine import numpy as np -from mmengine.fileio import (get_file_backend, isdir, join_path, - list_dir_or_file) +from rich.progress import track from mmocr.registry import VISUALIZERS -from mmocr.structures.textdet_data_sample import TextDetDataSample +from mmocr.structures import TextSpottingDataSample from mmocr.utils import ConfigType, bbox2poly, crop_img, poly2bbox from .base_mmocr_inferencer import (BaseMMOCRInferencer, InputsType, PredType, ResType) @@ -65,7 +64,6 @@ class MMOCRInferencer(BaseMMOCRInferencer): 'provided.') self.visualizer = None - self.num_visualized_imgs = 0 if det is not None: self.textdet_inferencer = TextDetInferencer( @@ -93,43 +91,19 @@ class MMOCRInferencer(BaseMMOCRInferencer): self.kie_inferencer = KIEInferencer(kie, kie_weights, device) self.mode = 'det_rec_kie' - def _inputs_to_list(self, inputs: InputsType) -> list: - """Preprocess the inputs to a list. - - Preprocess inputs to a list according to its type: - - - list or tuple: return inputs - - str: - - Directory path: return all files in the directory - - normal string: return a list containing the string - - Args: - inputs (InputsType): Inputs for the inferencer. - - Returns: - list: List of input for the :meth:`preprocess`. 
- """ - inputs = copy.deepcopy(inputs) - if isinstance(inputs, str): - backend = get_file_backend(inputs) - if hasattr(backend, 'isdir') and isdir(inputs): - # Backends like HttpsBackend do not implement `isdir`, so only - # those backends that implement `isdir` could accept the inputs - # as a directory - filename_list = list_dir_or_file(inputs, list_dir=False) - inputs = [ - join_path(inputs, filename) for filename in filename_list - ] - - if not isinstance(inputs, (list, tuple)): - inputs = [inputs] - - for i in range(len(inputs)): - if not isinstance(inputs[i], np.ndarray): - img_bytes = mmengine.fileio.get(inputs[i]) - inputs[i] = mmcv.imfrombytes(img_bytes) - - return list(inputs) + def _inputs2ndarrray(self, inputs: List[InputsType]) -> List[np.ndarray]: + """Preprocess the inputs to a list of numpy arrays.""" + new_inputs = [] + for item in inputs: + if isinstance(item, np.ndarray): + new_inputs.append(item) + elif isinstance(item, str): + img_bytes = mmengine.fileio.get(item) + new_inputs.append(mmcv.imfrombytes(img_bytes)) + else: + raise NotImplementedError(f'The input type {type(item)} is not' + 'supported yet.') + return new_inputs def forward(self, inputs: InputsType, batch_size: int, **forward_kwargs) -> PredType: @@ -144,6 +118,7 @@ class MMOCRInferencer(BaseMMOCRInferencer): "kie".. """ result = {} + forward_kwargs['progress_bar'] = False if self.mode == 'rec': # The extra list wrapper here is for the ease of postprocessing self.rec_inputs = inputs @@ -153,15 +128,16 @@ class MMOCRInferencer(BaseMMOCRInferencer): batch_size=batch_size, **forward_kwargs)['predictions'] result['rec'] = [[p] for p in predictions] - elif self.mode.startswith('det'): + elif self.mode.startswith('det'): # 'det'/'det_rec'/'det_rec_kie' result['det'] = self.textdet_inferencer( inputs, return_datasamples=True, batch_size=batch_size, **forward_kwargs)['predictions'] - if self.mode.startswith('det_rec'): + if self.mode.startswith('det_rec'): # 'det_rec'/'det_rec_kie' result['rec'] = [] - for img, det_data_sample in zip(inputs, result['det']): + for img, det_data_sample in zip( + self._inputs2ndarrray(inputs), result['det']): det_pred = det_data_sample.pred_instances self.rec_inputs = [] for polygon in det_pred['polygons']: @@ -177,6 +153,10 @@ class MMOCRInferencer(BaseMMOCRInferencer): **forward_kwargs)['predictions']) if self.mode == 'det_rec_kie': self.kie_inputs = [] + # TODO: when the det output is empty, kie will fail + # as no gt-instances can be provided. It's a known + # issue but cannot be solved elegantly since we support + # batch inference. for img, det_data_sample, rec_data_samples in zip( inputs, result['det'], result['rec']): det_pred = det_data_sample.pred_instances @@ -197,7 +177,7 @@ class MMOCRInferencer(BaseMMOCRInferencer): return result def visualize(self, inputs: InputsType, preds: PredType, - **kwargs) -> List[np.ndarray]: + **kwargs) -> Union[List[np.ndarray], None]: """Visualize predictions. Args: @@ -210,7 +190,14 @@ class MMOCRInferencer(BaseMMOCRInferencer): Defaults to True. pred_score_thr (float): Minimum score of bboxes to draw. Defaults to 0.3. - img_out_dir (str): Output directory of images. Defaults to ''. + save_vis (bool): Whether to save the visualization result. Defaults + to False. + img_out_dir (str): Output directory of visualization results. + If left as empty, no file will be saved. Defaults to ''. + + Returns: + List[np.ndarray] or None: Returns visualization results only if + applicable. 
""" if 'kie' in self.mode: @@ -232,6 +219,9 @@ class MMOCRInferencer(BaseMMOCRInferencer): self, inputs: InputsType, batch_size: int = 1, + out_dir: str = 'results/', + save_vis: bool = False, + save_pred: bool = False, **kwargs, ) -> dict: """Call the inferencer. @@ -239,9 +229,12 @@ class MMOCRInferencer(BaseMMOCRInferencer): Args: inputs (InputsType): Inputs for the inferencer. It can be a path to image / image directory, or an array, or a list of these. - return_datasamples (bool): Whether to return results as - :obj:`BaseDataElement`. Defaults to False. batch_size (int): Batch size. Defaults to 1. + out_dir (str): Output directory of results. Defaults to 'results/'. + save_vis (bool): Whether to save the visualization results to + "out_dir". Defaults to False. + save_pred (bool): Whether to save the inference results to + "out_dir". Defaults to False. **kwargs: Key words arguments passed to :meth:`preprocess`, :meth:`forward`, :meth:`visualize` and :meth:`postprocess`. Each key in kwargs should be in the corresponding set of @@ -249,47 +242,75 @@ class MMOCRInferencer(BaseMMOCRInferencer): and ``postprocess_kwargs``. Returns: - dict: Inference and visualization results. + dict: Inference and visualization results, mapped from + "predictions" and "visualization". """ + if (save_vis or save_pred) and not out_dir: + raise ValueError('out_dir must be specified when save_vis or ' + 'save_pred is True!') + if out_dir: + img_out_dir = osp.join(out_dir, 'vis') + pred_out_dir = osp.join(out_dir, 'preds') + else: + img_out_dir, pred_out_dir = '', '' + ( preprocess_kwargs, forward_kwargs, visualize_kwargs, postprocess_kwargs, - ) = self._dispatch_kwargs(**kwargs) + ) = self._dispatch_kwargs( + save_vis=save_vis, save_pred=save_pred, **kwargs) ori_inputs = self._inputs_to_list(inputs) - preds = self.forward(ori_inputs, batch_size, **forward_kwargs) - - visualization = self.visualize( - ori_inputs, preds, - **visualize_kwargs) # type: ignore # noqa: E501 - results = self.postprocess(preds, visualization, **postprocess_kwargs) + chunked_inputs = super(BaseMMOCRInferencer, + self)._get_chunk_data(ori_inputs, batch_size) + results = {'predictions': [], 'visualization': []} + for ori_input in track(chunked_inputs, description='Inference'): + preds = self.forward(ori_input, batch_size, **forward_kwargs) + visualization = self.visualize( + ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs) + batch_res = self.postprocess( + preds, + visualization, + pred_out_dir=pred_out_dir, + **postprocess_kwargs) + results['predictions'].extend(batch_res['predictions']) + if batch_res['visualization'] is not None: + results['visualization'].extend(batch_res['visualization']) return results def postprocess(self, preds: PredType, visualization: Optional[List[np.ndarray]] = None, print_result: bool = False, - pred_out_file: str = '' + save_pred: bool = False, + pred_out_dir: str = '' ) -> Union[ResType, Tuple[ResType, np.ndarray]]: - """Postprocess predictions. + """Process the predictions and visualization results from ``forward`` + and ``visualize``. + + This method should be responsible for the following tasks: + + 1. Convert datasamples into a json-serializable dict if needed. + 2. Pack the predictions and visualization results and return them. + 3. Dump or log the predictions. Args: - preds (Dict): Predictions of the model. + preds (PredType): Predictions of the model. visualization (Optional[np.ndarray]): Visualized predictions. print_result (bool): Whether to print the result. Defaults to False. 
- pred_out_file (str): Output file name to store predictions - without images. Supported file formats are “json”, “yaml/yml” - and “pickle/pkl”. Defaults to ''. + save_pred (bool): Whether to save the inference result. Defaults to + False. + pred_out_dir (str): Directory to save the inference results w/o + visualization. If left as empty, no file will be saved. + Defaults to ''. Returns: - Dict or List[Dict]: Each dict contains the inference result of - each image. Possible keys are "det_polygons", "det_scores", - "rec_texts", "rec_scores", "kie_labels", "kie_scores", - "kie_edge_labels" and "kie_edge_scores". + Dict: Inference and visualization results, mapped from + "predictions" and "visualization". """ result_dict = {} @@ -320,22 +341,28 @@ class MMOCRInferencer(BaseMMOCRInferencer): kie_edge_scores=kie_dict_res['edge_scores'], kie_edge_labels=kie_dict_res['edge_labels']) + if save_pred and pred_out_dir: + pred_key = 'det' if 'det' in self.mode else 'rec' + for pred, pred_result in zip(preds[pred_key], pred_results): + img_path = ( + pred.img_path if pred_key == 'det' else pred[0].img_path) + pred_name = osp.splitext(osp.basename(img_path))[0] + pred_name = f'{pred_name}.json' + pred_out_file = osp.join(pred_out_dir, pred_name) + mmengine.dump(pred_result, pred_out_file) + result_dict['predictions'] = pred_results if print_result: print(result_dict) - if pred_out_file != '': - mmengine.dump(result_dict, pred_out_file) result_dict['visualization'] = visualization return result_dict - def _pack_e2e_datasamples(self, preds: Dict) -> List[TextDetDataSample]: + def _pack_e2e_datasamples(self, + preds: Dict) -> List[TextSpottingDataSample]: """Pack text detection and recognition results into a list of - TextDetDataSample. - - Note that it is a temporary solution since the TextSpottingDataSample - is not ready. - """ + TextSpottingDataSample.""" results = [] + for det_data_sample, rec_data_samples in zip(preds['det'], preds['rec']): texts = [] diff --git a/mmocr/ocr.py b/mmocr/ocr.py index 216f9a41..74ff9099 100755 --- a/mmocr/ocr.py +++ b/mmocr/ocr.py @@ -9,10 +9,10 @@ def parse_args(): parser.add_argument( 'inputs', type=str, help='Input image file or folder path.') parser.add_argument( - '--img-out-dir', + '--out-dir', type=str, - default='', - help='Output directory of images.') + default='results/', + help='Output directory of results.') parser.add_argument( '--det', type=str, @@ -69,10 +69,13 @@ def parse_args(): action='store_true', help='Whether to print the results.') parser.add_argument( - '--pred-out-file', - type=str, - default='', - help='File to save the inference results.') + '--save_pred', + action='store_true', + help='Save the inference results to out_dir.') + parser.add_argument( + '--save_vis', + action='store_true', + help='Save the visualization results to out_dir.') call_args = vars(parser.parse_args()) diff --git a/mmocr/visualization/base_visualizer.py b/mmocr/visualization/base_visualizer.py index b53972cd..868f6585 100644 --- a/mmocr/visualization/base_visualizer.py +++ b/mmocr/visualization/base_visualizer.py @@ -85,6 +85,8 @@ class BaseLocalVisualizer(Visualizer): font_families (Union[str, List[str]]): The font families of labels. Defaults to 'sans-serif'.
""" + if not labels and not bboxes: + return image if colors is not None and isinstance(colors, (list, tuple)): size = math.ceil(len(labels) / len(colors)) colors = (colors * size)[:len(labels)] diff --git a/mmocr/visualization/textspotting_visualizer.py b/mmocr/visualization/textspotting_visualizer.py index 6d712261..c5a1dc19 100644 --- a/mmocr/visualization/textspotting_visualizer.py +++ b/mmocr/visualization/textspotting_visualizer.py @@ -44,18 +44,19 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer): img_shape = image.shape[:2] empty_shape = (img_shape[0], img_shape[1], 3) text_image = np.full(empty_shape, 255, dtype=np.uint8) - text_image = self.get_labels_image( - text_image, - labels=texts, - bboxes=bboxes, - font_families=self.font_families) + if texts: + text_image = self.get_labels_image( + text_image, + labels=texts, + bboxes=bboxes, + font_families=self.font_families) if polygons: polygons = [polygon.reshape(-1, 2) for polygon in polygons] image = self.get_polygons_image( image, polygons, filling=True, colors=self.PALETTE) text_image = self.get_polygons_image( text_image, polygons, colors=self.PALETTE) - else: + elif len(bboxes) > 0: image = self.get_bboxes_image( image, bboxes, filling=True, colors=self.PALETTE) text_image = self.get_bboxes_image( @@ -103,27 +104,28 @@ class TextSpottingLocalVisualizer(BaseLocalVisualizer): """ cat_images = [] - if draw_gt: - gt_bboxes = data_sample.gt_instances.get('bboxes', None) - gt_texts = data_sample.gt_instances.texts - gt_polygons = data_sample.gt_instances.get('polygons', None) - gt_img_data = self._draw_instances(image, gt_bboxes, gt_polygons, - gt_texts) - cat_images.append(gt_img_data) + if data_sample is not None: + if draw_gt and 'gt_instances' in data_sample: + gt_bboxes = data_sample.gt_instances.get('bboxes', None) + gt_texts = data_sample.gt_instances.texts + gt_polygons = data_sample.gt_instances.get('polygons', None) + gt_img_data = self._draw_instances(image, gt_bboxes, + gt_polygons, gt_texts) + cat_images.append(gt_img_data) - if draw_pred: - pred_instances = data_sample.pred_instances - pred_instances = pred_instances[ - pred_instances.scores > pred_score_thr].cpu().numpy() - pred_bboxes = pred_instances.get('bboxes', None) - pred_texts = pred_instances.texts - pred_polygons = pred_instances.get('polygons', None) - if pred_bboxes is None: - pred_bboxes = [poly2bbox(poly) for poly in pred_polygons] - pred_bboxes = np.array(pred_bboxes) - pred_img_data = self._draw_instances(image, pred_bboxes, - pred_polygons, pred_texts) - cat_images.append(pred_img_data) + if draw_pred and 'pred_instances' in data_sample: + pred_instances = data_sample.pred_instances + pred_instances = pred_instances[ + pred_instances.scores > pred_score_thr].cpu().numpy() + pred_bboxes = pred_instances.get('bboxes', None) + pred_texts = pred_instances.texts + pred_polygons = pred_instances.get('polygons', None) + if pred_bboxes is None: + pred_bboxes = [poly2bbox(poly) for poly in pred_polygons] + pred_bboxes = np.array(pred_bboxes) + pred_img_data = self._draw_instances(image, pred_bboxes, + pred_polygons, pred_texts) + cat_images.append(pred_img_data) cat_images = self._cat_image(cat_images, axis=0) if cat_images is None: diff --git a/tests/test_apis/test_inferencers/test_kie_inferencer.py b/tests/test_apis/test_inferencers/test_kie_inferencer.py index 7b431dd8..5f4fa545 100644 --- a/tests/test_apis/test_inferencers/test_kie_inferencer.py +++ b/tests/test_apis/test_inferencers/test_kie_inferencer.py @@ -75,12 +75,15 @@ class 
TestKIEInferencer(TestCase): def assert_predictions_equal(self, preds1, preds2): for pred1, pred2 in zip(preds1, preds2): - self.assertTrue(np.allclose(pred1['labels'], pred2['labels'], 0.1)) - self.assertTrue( - np.allclose(pred1['edge_scores'], pred2['edge_scores'], 0.1)) - self.assertTrue( - np.allclose(pred1['edge_labels'], pred2['edge_labels'], 0.1)) - self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1)) + self.assert_prediction_equal(pred1, pred2) + + def assert_prediction_equal(self, pred1, pred2): + self.assertTrue(np.allclose(pred1['labels'], pred2['labels'], 0.1)) + self.assertTrue( + np.allclose(pred1['edge_scores'], pred2['edge_scores'], 0.1)) + self.assertTrue( + np.allclose(pred1['edge_labels'], pred2['edge_labels'], 0.1)) + self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1)) def test_call(self): # no visual, single input @@ -130,9 +133,9 @@ class TestKIEInferencer(TestCase): # img_out_dir with tempfile.TemporaryDirectory() as tmp_dir: - self.inferencer(self.data_img_str, img_out_dir=tmp_dir) - for img_dir in ['00000000.jpg', '00000001.jpg']: - self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir))) + self.inferencer(self.data_img_str, out_dir=tmp_dir, save_vis=True) + for img_dir in ['1.jpg', '2.jpg']: + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) def test_postprocess(self): # return_datasample @@ -141,14 +144,19 @@ class TestKIEInferencer(TestCase): # pred_out_file with tempfile.TemporaryDirectory() as tmp_dir: - pred_out_file = osp.join(tmp_dir, 'tmp.pkl') res = self.inferencer( self.data_img_ndarray, print_result=True, - pred_out_file=pred_out_file) - dumped_res = mmengine.load(pred_out_file) - self.assert_predictions_equal(res['predictions'], - dumped_res['predictions']) + out_dir=tmp_dir, + save_pred=True) + file_names = [ + f'{self.inferencer.num_unnamed_imgs - i}.json' + for i in range(len(self.data_img_ndarray), 0, -1) + ] + for pred, file_name in zip(res['predictions'], file_names): + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', file_name)) + self.assert_prediction_equal(dumped_res, pred) @mock.patch('mmocr.apis.inferencers.kie_inferencer._load_checkpoint') def test_load_metainfo_to_visualizer(self, mock_load): diff --git a/tests/test_apis/test_inferencers/test_mmocr_inferencer.py b/tests/test_apis/test_inferencers/test_mmocr_inferencer.py index f084ef9f..eebcd721 100644 --- a/tests/test_apis/test_inferencers/test_mmocr_inferencer.py +++ b/tests/test_apis/test_inferencers/test_mmocr_inferencer.py @@ -95,6 +95,18 @@ class TestMMOCRInferencer(TestCase): np.allclose(res_img_ndarray['visualization'][0], res_img_path['visualization'][0])) + # test save_vis and save_pred + with tempfile.TemporaryDirectory() as tmp_dir: + res = inferencer( + img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True) + for img_dir in ['img_1.jpg', 'img_2.jpg']: + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) + for i, pred_dir in enumerate(['img_1.json', 'img_2.json']): + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', pred_dir)) + self.assert_predictions_equal(res['predictions'][i], + dumped_res) + @mock.patch('mmengine.infer.infer._load_checkpoint') def test_rec(self, mock_load): mock_load.side_effect = lambda *x, **y: None @@ -129,6 +141,18 @@ class TestMMOCRInferencer(TestCase): np.allclose(res_img_ndarray['visualization'][0], res_img_path['visualization'][0])) + # test save_vis and save_pred + with tempfile.TemporaryDirectory() as tmp_dir: + res = inferencer( + img_paths, out_dir=tmp_dir, 
save_vis=True, save_pred=True) + for img_dir in ['1036169.jpg', '1058891.jpg']: + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) + for i, pred_dir in enumerate(['1036169.json', '1058891.json']): + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', pred_dir)) + self.assert_predictions_equal(res['predictions'][i], + dumped_res) + @mock.patch('mmengine.infer.infer._load_checkpoint') def test_det_rec(self, mock_load): mock_load.side_effect = lambda *x, **y: None @@ -166,6 +190,21 @@ class TestMMOCRInferencer(TestCase): np.allclose(res_img_ndarray['visualization'][0], res_img_path['visualization'][0])) + # test save_vis and save_pred + with tempfile.TemporaryDirectory() as tmp_dir: + res = inferencer( + img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True) + for img_dir in ['img_1.jpg', 'img_2.jpg']: + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) + for i, pred_dir in enumerate(['img_1.json', 'img_2.json']): + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', pred_dir)) + self.assert_predictions_equal(res['predictions'][i], + dumped_res) + + # corner case: when the det model cannot detect any texts + inferencer(np.zeros((100, 100, 3)), return_vis=True) + @mock.patch('mmengine.infer.infer._load_checkpoint') def test_dec_rec_kie(self, mock_load): mock_load.side_effect = lambda *x, **y: None @@ -205,18 +244,14 @@ class TestMMOCRInferencer(TestCase): np.allclose(res_img_ndarray['visualization'][0], res_img_path['visualization'][0])) - # test visualization - # img_out_dir + # test save_vis and save_pred with tempfile.TemporaryDirectory() as tmp_dir: - inferencer(img_paths, img_out_dir=tmp_dir) - for img_dir in ['00000006.jpg', '00000007.jpg']: - self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir))) - - # pred_out_file - with tempfile.TemporaryDirectory() as tmp_dir: - pred_out_file = osp.join(tmp_dir, 'tmp.pkl') res = inferencer( - img_path, print_result=True, pred_out_file=pred_out_file) - dumped_res = mmengine.load(pred_out_file) - self.assert_predictions_equal(res['predictions'], - dumped_res['predictions']) + img_paths, out_dir=tmp_dir, save_vis=True, save_pred=True) + for img_dir in ['1.jpg', '2.jpg']: + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) + for i, pred_dir in enumerate(['1.json', '2.json']): + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', pred_dir)) + self.assert_predictions_equal(res['predictions'][i], + dumped_res) diff --git a/tests/test_apis/test_inferencers/test_textdet_inferencer.py b/tests/test_apis/test_inferencers/test_textdet_inferencer.py index 60130955..badb4eaa 100644 --- a/tests/test_apis/test_inferencers/test_textdet_inferencer.py +++ b/tests/test_apis/test_inferencers/test_textdet_inferencer.py @@ -41,9 +41,11 @@ class TestTextDetinferencer(TestCase): def assert_predictions_equal(self, preds1, preds2): for pred1, pred2 in zip(preds1, preds2): - self.assertTrue( - np.allclose(pred1['polygons'], pred2['polygons'], 0.1)) - self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1)) + self.assert_prediction_equal(pred1, pred2) + + def assert_prediction_equal(self, pred1, pred2): + self.assertTrue(np.allclose(pred1['polygons'], pred2['polygons'], 0.1)) + self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1)) def test_call(self): # single img @@ -91,9 +93,9 @@ class TestTextDetinferencer(TestCase): # img_out_dir with tempfile.TemporaryDirectory() as tmp_dir: - self.inferencer(img_paths, img_out_dir=tmp_dir) + self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True) for 
img_dir in ['img_1.jpg', 'img_2.jpg']: - self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir))) + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) def test_postprocess(self): # return_datasample @@ -101,14 +103,13 @@ class TestTextDetinferencer(TestCase): res = self.inferencer(img_path, return_datasamples=True) self.assertTrue(is_type_list(res['predictions'], TextDetDataSample)) - # pred_out_file + # dump predictions with tempfile.TemporaryDirectory() as tmp_dir: - pred_out_file = osp.join(tmp_dir, 'tmp.pkl') res = self.inferencer( - img_path, print_result=True, pred_out_file=pred_out_file) - dumped_res = mmengine.load(pred_out_file) - self.assert_predictions_equal(res['predictions'], - dumped_res['predictions']) + img_path, print_result=True, out_dir=tmp_dir, save_pred=True) + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', 'img_1.json')) + self.assert_prediction_equal(res['predictions'][0], dumped_res) def test_pred2dict(self): data_sample = TextDetDataSample() diff --git a/tests/test_apis/test_inferencers/test_textrec_inferencer.py b/tests/test_apis/test_inferencers/test_textrec_inferencer.py index 801110cf..1e89e4a9 100644 --- a/tests/test_apis/test_inferencers/test_textrec_inferencer.py +++ b/tests/test_apis/test_inferencers/test_textrec_inferencer.py @@ -40,8 +40,11 @@ class TestTextRecinferencer(TestCase): def assert_predictions_equal(self, preds1, preds2): for pred1, pred2 in zip(preds1, preds2): - self.assertEqual(pred1['text'], pred2['text']) - self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1)) + self.assert_prediction_equal(pred1, pred2) + + def assert_prediction_equal(self, pred1, pred2): + self.assertEqual(pred1['text'], pred2['text']) + self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1)) def test_call(self): # single img @@ -86,9 +89,9 @@ class TestTextRecinferencer(TestCase): # img_out_dir with tempfile.TemporaryDirectory() as tmp_dir: - self.inferencer(img_paths, img_out_dir=tmp_dir) + self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True) for img_dir in ['1036169.jpg', '1058891.jpg']: - self.assertTrue(osp.exists(osp.join(tmp_dir, img_dir))) + self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir))) def test_postprocess(self): # return_datasample @@ -98,9 +101,8 @@ class TestTextRecinferencer(TestCase): # pred_out_file with tempfile.TemporaryDirectory() as tmp_dir: - pred_out_file = osp.join(tmp_dir, 'tmp.pkl') res = self.inferencer( - img_path, print_result=True, pred_out_file=pred_out_file) - dumped_res = mmengine.load(pred_out_file) - self.assert_predictions_equal(res['predictions'], - dumped_res['predictions']) + img_path, print_result=True, out_dir=tmp_dir, save_pred=True) + dumped_res = mmengine.load( + osp.join(tmp_dir, 'preds', '1036169.json')) + self.assert_prediction_equal(res['predictions'][0], dumped_res)
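Taken together, the updated tests pin down the new output contract: each input image yields one JSON file under ``out_dir/preds/`` and one visualization under ``out_dir/vis/``, both named after the input stem, while ndarray inputs are numbered from the ``num_unnamed_imgs`` counter. An end-to-end sketch of that flow (the model aliases and image path below are illustrative placeholders):

.. code-block:: python

    from mmocr.apis.inferencers import MMOCRInferencer

    ocr = MMOCRInferencer(det='DBNet', rec='CRNN')
    res = ocr(
        'demo/demo_text_ocr.jpg',
        out_dir='results/',
        save_vis=True,   # -> results/vis/demo_text_ocr.jpg
        save_pred=True)  # -> results/preds/demo_text_ocr.json
    print(res['predictions'][0]['rec_texts'])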
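The ``kie_collate`` helper introduced in ``kie_inferencer.py`` can also be illustrated in isolation: element 0 of every sample (the raw input dict) passes through untouched, while the remaining elements are batched by ``pseudo_collate``. A toy demonstration (the dicts below are stand-ins, not real pipeline outputs):

.. code-block:: python

    from mmengine.dataset import pseudo_collate

    batch = [
        ({'img': 'a.jpg'}, {'idx': 0}),
        ({'img': 'b.jpg'}, {'idx': 1}),
    ]
    transposed = list(zip(*batch))
    for i in range(1, len(transposed)):
        transposed[i] = pseudo_collate(transposed[i])
    # transposed[0] == ({'img': 'a.jpg'}, {'img': 'b.jpg'})  # kept as-is
    # transposed[1] == {'idx': [0, 1]}  # values batched per key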