diff --git a/configs/textrecog/abinet/base.py b/configs/textrecog/abinet/base.py index f84ee8bd..a00d2588 100644 --- a/configs/textrecog/abinet/base.py +++ b/configs/textrecog/abinet/base.py @@ -11,7 +11,11 @@ file_client_args = dict(backend='disk') default_hooks = dict(logger=dict(type='LoggerHook', interval=100)) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict(type='Resize', scale=(128, 32)), dict( diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py index b671c05a..35f7c93c 100644 --- a/configs/textrecog/crnn/crnn_academic_dataset.py +++ b/configs/textrecog/crnn/crnn_academic_dataset.py @@ -17,7 +17,9 @@ train_pipeline = [ dict( type='LoadImageFromFile', color_type='grayscale', - file_client_args=file_client_args), + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict(type='Resize', scale=(100, 32), keep_ratio=False), dict( diff --git a/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py b/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py index b244733d..4446d7b6 100644 --- a/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py +++ b/configs/textrecog/master/master_r31_12e_ST_MJ_SA.py @@ -12,7 +12,11 @@ file_client_args = dict(backend='disk') default_hooks = dict(logger=dict(type='LoggerHook', interval=50), ) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict( type='RescaleToHeight', diff --git a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py index 3b2480a5..da4e1b40 100644 --- a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py +++ b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py @@ -9,7 +9,11 @@ file_client_args = dict(backend='disk') default_hooks = dict(logger=dict(type='LoggerHook', interval=100)) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict(type='Resize', scale=(160, 48), keep_ratio=False), dict( diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py index 1566cd18..a46d62b5 100644 --- a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py @@ -11,7 +11,11 @@ file_client_args = dict(backend='disk') default_hooks = dict(logger=dict(type='LoggerHook', interval=100)) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict( type='RescaleToHeight', diff --git a/configs/textrecog/satrn/satrn_academic.py b/configs/textrecog/satrn/satrn_academic.py index 8776a290..243ea0a9 100644 --- a/configs/textrecog/satrn/satrn_academic.py +++ b/configs/textrecog/satrn/satrn_academic.py @@ -42,7 +42,11 @@ model = dict( postprocessor=dict(type='AttentionPostprocessor'))) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='LoadImageFromFile', + file_client_args=file_client_args, + ignore_empty=True, + min_size=5), dict(type='LoadOCRAnnotations', with_text=True), dict(type='Resize', scale=(100, 32), keep_ratio=False), dict( diff --git a/mmocr/datasets/transforms/__init__.py b/mmocr/datasets/transforms/__init__.py index 8a9ba62f..b14c47d9 100644 --- a/mmocr/datasets/transforms/__init__.py +++ b/mmocr/datasets/transforms/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .adapters import MMDet2MMOCR, MMOCR2MMDet from .formatting import PackKIEInputs, PackTextDetInputs, PackTextRecogInputs -from .loading import LoadImageFromLMDB, LoadKIEAnnotations, LoadOCRAnnotations +from .loading import (LoadImageFromFile, LoadImageFromLMDB, LoadKIEAnnotations, + LoadOCRAnnotations) from .ocr_transforms import RandomCrop, RandomRotate, Resize from .textdet_transforms import (BoundedScaleAspectJitter, FixInvalidPolygon, RandomFlip, ShortScaleAspectJitter, @@ -17,5 +18,5 @@ __all__ = [ 'PackTextRecogInputs', 'RescaleToHeight', 'PadToWidth', 'ShortScaleAspectJitter', 'RandomFlip', 'BoundedScaleAspectJitter', 'PackKIEInputs', 'LoadKIEAnnotations', 'FixInvalidPolygon', 'MMDet2MMOCR', - 'MMOCR2MMDet', 'LoadImageFromLMDB' + 'MMOCR2MMDet', 'LoadImageFromLMDB', 'LoadImageFromFile' ] diff --git a/mmocr/datasets/transforms/loading.py b/mmocr/datasets/transforms/loading.py index b2aaa2df..70ff11a3 100644 --- a/mmocr/datasets/transforms/loading.py +++ b/mmocr/datasets/transforms/loading.py @@ -7,10 +7,83 @@ import mmcv import numpy as np from mmcv.transforms import BaseTransform from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations +from mmcv.transforms import LoadImageFromFile as MMCV_LoadImageFromFile from mmocr.registry import TRANSFORMS +@TRANSFORMS.register_module() +class LoadImageFromFile(MMCV_LoadImageFromFile): + """Load an image from file. + + Required Keys: + + - img_path + + Modified Keys: + + - img + - img_shape + - ori_shape + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:``mmcv.imfrombytes``. + Defaults to 'color'. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :func:``mmcv.imfrombytes`` for details. + Defaults to 'cv2'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + ignore_empty (bool): Whether to allow loading empty image or file path + not existent. Defaults to False. + min_size (int): The minimum size of the image to be loaded. If the + image is smaller than the minimum size, it will be ignored. + Defaults to 0. + """ + + def __init__(self, + to_float32: bool = False, + color_type: str = 'color', + imdecode_backend: str = 'cv2', + file_client_args: dict = dict(backend='disk'), + min_size: int = 0, + ignore_empty: bool = False) -> None: + self.ignore_empty = ignore_empty + self.to_float32 = to_float32 + self.color_type = color_type + self.imdecode_backend = imdecode_backend + self.file_client_args = file_client_args.copy() + self.file_client = mmcv.FileClient(**self.file_client_args) + self.min_size = min_size + + def transform(self, results: dict) -> Optional[dict]: + """Functions to load image. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + """ + results = super().transform(results) + if min(results['ori_shape']) < self.min_size: + return None + else: + return results + + def __repr__(self): + repr_str = (f'{self.__class__.__name__}(' + f'ignore_empty={self.ignore_empty}, ' + f'min_size={self.min_size}, ' + f'to_float32={self.to_float32}, ' + f"color_type='{self.color_type}', " + f"imdecode_backend='{self.imdecode_backend}', " + f'file_client_args={self.file_client_args})') + return repr_str + + @TRANSFORMS.register_module() class LoadOCRAnnotations(MMCV_LoadAnnotations): """Load and process the ``instances`` annotation provided by dataset. diff --git a/tests/datasets/transforms/test_loading.py b/tests/datasets/transforms/test_loading.py index dc0de156..b4699a14 100644 --- a/tests/datasets/transforms/test_loading.py +++ b/tests/datasets/transforms/test_loading.py @@ -1,11 +1,38 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import os.path as osp from unittest import TestCase import numpy as np -from mmocr.datasets.transforms import (LoadImageFromLMDB, LoadKIEAnnotations, - LoadOCRAnnotations) +from mmocr.datasets.transforms import (LoadImageFromFile, LoadImageFromLMDB, + LoadKIEAnnotations, LoadOCRAnnotations) + + +class TestLoadImageFromFile(TestCase): + + def test_load_img(self): + data_prefix = osp.join( + osp.dirname(__file__), '../../data/rec_toy_dataset/imgs/') + + results = dict(img_path=osp.join(data_prefix, '1036169.jpg')) + transform = LoadImageFromFile(min_size=0) + results = transform(copy.deepcopy(results)) + self.assertEquals(results['img_path'], + osp.join(data_prefix, '1036169.jpg')) + self.assertEquals(results['img'].shape, (25, 119, 3)) + self.assertEquals(results['img'].dtype, np.uint8) + self.assertEquals(results['img_shape'], (25, 119)) + self.assertEquals(results['ori_shape'], (25, 119)) + self.assertEquals( + repr(transform), + ('LoadImageFromFile(ignore_empty=False, min_size=0, ' + "to_float32=False, color_type='color', imdecode_backend='cv2', " + "file_client_args={'backend': 'disk'})")) + + transform = LoadImageFromFile(min_size=26) + results = transform(copy.deepcopy(results)) + self.assertIsNone(results) class TestLoadOCRAnnotations(TestCase): @@ -179,9 +206,3 @@ class TestLoadImageFromLMDB(TestCase): "to_float32=False, color_type='color', " "imdecode_backend='cv2', " "file_client_args={'backend': 'lmdb', 'db_path': ''})") - - -if __name__ == '__main__': - test = TestLoadImageFromLMDB() - test.setUp() - test.test_transform()