[Fix] fix load error (#1212)

pull/1215/head
liukuikun 2022-07-26 16:53:11 +08:00 committed by GitHub
parent 507f0656c9
commit 2cca103b93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 6 deletions

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from typing import Optional
import mmcv
@ -67,11 +68,42 @@ class LoadImageFromFile(MMCV_LoadImageFromFile):
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
"""
results = super().transform(results)
if results and min(results['ori_shape']) < self.min_size:
"""Functions to load image.
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
Returns:
dict: The dict contains loaded image and meta information.
"""
filename = results['img_path']
try:
img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(
img_bytes, flag=self.color_type, backend=self.imdecode_backend)
except Exception as e:
if self.ignore_empty:
warnings.warn(f'Failed to load {filename} due to {e}')
return None
else:
raise e
if self.ignore_empty and img is None:
warnings.warn(f'Ignore broken image: {filename}')
return None
else:
return results
elif img is None:
raise IOError(f'{filename} is broken')
if self.to_float32:
img = img.astype(np.float32)
if min(img.shape[:2]) < self.min_size:
return None
results['img'] = img
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('

Binary file not shown.

After

Width:  |  Height:  |  Size: 512 B

View File

@ -30,12 +30,33 @@ class TestLoadImageFromFile(TestCase):
"to_float32=False, color_type='color', imdecode_backend='cv2', "
"file_client_args={'backend': 'disk'})"))
# to_float32
transform = LoadImageFromFile(to_float32=True)
results = transform(copy.deepcopy(results))
self.assertEquals(results['img'].dtype, np.float32)
# min_size
transform = LoadImageFromFile(min_size=26)
results = transform(copy.deepcopy(results))
self.assertIsNone(results)
results = dict(img_path='fake.jpg')
transform = LoadImageFromFile(min_size=26, ignore_empty=True)
# test load empty
fake_img_path = osp.join(data_prefix, 'fake.jpg')
results = dict(img_path=fake_img_path)
transform = LoadImageFromFile(ignore_empty=False)
with self.assertRaises(FileNotFoundError):
transform(copy.deepcopy(results))
transform = LoadImageFromFile(ignore_empty=True)
results = transform(copy.deepcopy(results))
self.assertIsNone(results)
data_prefix = osp.join(osp.dirname(__file__), '../../data')
broken_img_path = osp.join(data_prefix, 'broken.jpg')
results = dict(img_path=broken_img_path)
transform = LoadImageFromFile(ignore_empty=False)
with self.assertRaises(IOError):
transform(copy.deepcopy(results))
transform = LoadImageFromFile(ignore_empty=True)
results = transform(copy.deepcopy(results))
self.assertIsNone(results)