mirror of https://github.com/open-mmlab/mmocr.git
[Fix] fix load error (#1212)
parent
507f0656c9
commit
2cca103b93
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import mmcv
|
||||
|
@ -67,11 +68,42 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
|
|||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
"""
|
||||
results = super().transform(results)
|
||||
if results and min(results['ori_shape']) < self.min_size:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
try:
|
||||
img_bytes = self.file_client.get(filename)
|
||||
img = mmcv.imfrombytes(
|
||||
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
|
||||
except Exception as e:
|
||||
if self.ignore_empty:
|
||||
warnings.warn(f'Failed to load {filename} due to {e}')
|
||||
return None
|
||||
else:
|
||||
raise e
|
||||
if self.ignore_empty and img is None:
|
||||
warnings.warn(f'Ignore broken image: {filename}')
|
||||
return None
|
||||
else:
|
||||
return results
|
||||
elif img is None:
|
||||
raise IOError(f'{filename} is broken')
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
if min(img.shape[:2]) < self.min_size:
|
||||
return None
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 512 B |
|
@ -30,12 +30,33 @@ class TestLoadImageFromFile(TestCase):
|
|||
"to_float32=False, color_type='color', imdecode_backend='cv2', "
|
||||
"file_client_args={'backend': 'disk'})"))
|
||||
|
||||
# to_float32
|
||||
transform = LoadImageFromFile(to_float32=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
self.assertEquals(results['img'].dtype, np.float32)
|
||||
|
||||
# min_size
|
||||
transform = LoadImageFromFile(min_size=26)
|
||||
results = transform(copy.deepcopy(results))
|
||||
self.assertIsNone(results)
|
||||
|
||||
results = dict(img_path='fake.jpg')
|
||||
transform = LoadImageFromFile(min_size=26, ignore_empty=True)
|
||||
# test load empty
|
||||
fake_img_path = osp.join(data_prefix, 'fake.jpg')
|
||||
results = dict(img_path=fake_img_path)
|
||||
transform = LoadImageFromFile(ignore_empty=False)
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
transform(copy.deepcopy(results))
|
||||
transform = LoadImageFromFile(ignore_empty=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
self.assertIsNone(results)
|
||||
|
||||
data_prefix = osp.join(osp.dirname(__file__), '../../data')
|
||||
broken_img_path = osp.join(data_prefix, 'broken.jpg')
|
||||
results = dict(img_path=broken_img_path)
|
||||
transform = LoadImageFromFile(ignore_empty=False)
|
||||
with self.assertRaises(IOError):
|
||||
transform(copy.deepcopy(results))
|
||||
transform = LoadImageFromFile(ignore_empty=True)
|
||||
results = transform(copy.deepcopy(results))
|
||||
self.assertIsNone(results)
|
||||
|
||||
|
|
Loading…
Reference in New Issue