Improved corruption handling during scan and cache (#999)
commit d61930e017
parent 0fda95aaf4
@@ -328,6 +328,12 @@ class LoadStreams:  # multiple IP or RTSP cameras
 class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
+
+        def img2label_paths(img_paths):
+            # Define label paths as a function of image paths
+            sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
+            return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]
+
         try:
             f = []  # image files
             for p in path if isinstance(path, list) else [path]:
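The new img2label_paths helper derives each label path from its image path by swapping the first /images/ path component for /labels/ and the image extension for .txt. A minimal standalone sketch of that substitution, using a made-up COCO-style path and assuming a POSIX os.sep:

import os

def img2label_paths(img_paths):
    # Swap the first /images/ component for /labels/ and the extension for .txt
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep
    return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]

print(img2label_paths(['coco/images/train2017/000000000009.jpg']))
# ['coco/labels/train2017/000000000009.txt']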
@@ -362,11 +368,8 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.mosaic_border = [-img_size // 2, -img_size // 2]
         self.stride = stride
 
-        # Define labels
-        sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
-        self.label_files = [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in self.img_files]
-
         # Check cache
+        self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
         if os.path.isfile(cache_path):
             cache = torch.load(cache_path)  # load
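The inline label-path list comprehension is replaced by the helper, and the cache file is expected next to the labels folder (e.g. .../labels.cache). A minimal sketch of the load-or-rebuild pattern around torch.load; load_label_cache, expected_hash and the rebuild callback are illustrative names, and the hash comparison mirrors validation that sits just outside this hunk:

import os
import torch
from pathlib import Path

def load_label_cache(label_files, expected_hash, rebuild):
    # Illustrative only: reuse the cached dict when it exists and its hash still
    # matches the current file list; otherwise rebuild and re-save it.
    cache_path = str(Path(label_files[0]).parent) + '.cache'
    if os.path.isfile(cache_path):
        cache = torch.load(cache_path)  # dict written by cache_labels()
        if cache.get('hash') == expected_hash:
            return cache
    return rebuild(cache_path)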
@@ -375,12 +378,15 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         else:
             cache = self.cache_labels(cache_path)  # cache
 
-        # Get labels
-        labels, shapes = zip(*[cache[x] for x in self.img_files])
-        self.shapes = np.array(shapes, dtype=np.float64)
+        # Read cache
+        cache.pop('hash')  # remove hash
+        labels, shapes = zip(*cache.values())
         self.labels = list(labels)
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.img_files = list(cache.keys())  # update
+        self.label_files = img2label_paths(cache.keys())  # update
 
-        # Rectangular Training https://github.com/ultralytics/yolov3/issues/232
+        # Rectangular Training
         if self.rect:
             # Sort by aspect ratio
             s = self.shapes  # wh
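The cache is now treated as the single source of truth: each entry maps an image path to its [labels, shape] pair, plus one 'hash' bookkeeping key. Popping the hash and unzipping the remaining values keeps img_files, labels and shapes index-aligned, and images skipped during scanning simply never appear. A small sketch with made-up entries:

import numpy as np

# Illustrative cache layout as written by cache_labels(): one entry per valid image
cache = {
    'images/train/a.jpg': [np.zeros((1, 5), dtype=np.float32), (640, 480)],
    'images/train/b.jpg': [np.zeros((0, 5), dtype=np.float32), (800, 600)],
    'hash': 1234.5,
}

cache.pop('hash')                      # drop the bookkeeping key
labels, shapes = zip(*cache.values())  # unzip into parallel tuples
img_files = list(cache.keys())         # stays aligned with labels/shapes
shapes = np.array(shapes, dtype=np.float64)
print(len(img_files), shapes.shape)    # 2 (2, 2)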
@@ -404,7 +410,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
 
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
 
-        # Cache labels
+        # Check labels
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
         pbar = enumerate(self.label_files)
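For rectangular training, each batch shape is rounded up to a multiple of the model stride so letterboxed images stay grid-aligned. A worked example of that line with made-up per-batch shapes (astype(int) is used here because np.int is deprecated in newer NumPy):

import numpy as np

shapes = np.array([[1.0, 0.75], [0.5, 1.0]])  # hypothetical per-batch aspect-ratio shapes
img_size, stride, pad = 640, 32, 0.5

batch_shapes = np.ceil(shapes * img_size / stride + pad).astype(int) * stride
print(batch_shapes)  # [[672 512]
                     #  [352 672]]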
@@ -483,10 +489,9 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         for (img, label) in pbar:
             try:
                 l = []
-                image = Image.open(img)
-                image.verify()  # PIL verify
-                # _ = io.imread(img)  # skimage verify (from skimage import io)
-                shape = exif_size(image)  # image size
+                im = Image.open(img)
+                im.verify()  # PIL verify
+                shape = exif_size(im)  # image size
                 assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                 if os.path.isfile(label):
                     with open(label, 'r') as f:
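The scan verifies each image with PIL before trusting it: Image.open reads only the header, verify() walks the rest of the file and raises on truncation or corruption, and exif_size (a yolov5 helper not shown in this hunk) returns the EXIF-corrected width/height. A minimal sketch of the same check using a hypothetical verify_image helper that skips the EXIF correction:

from PIL import Image

def verify_image(path):
    # Raises on corrupt/truncated files; returns (width, height) otherwise
    im = Image.open(path)
    im.verify()  # decodes enough of the file to detect corruption, keeps no pixel data
    w, h = im.size
    assert (w > 9) and (h > 9), 'image size <10 pixels'
    return w, h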
@@ -495,8 +500,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     l = np.zeros((0, 5), dtype=np.float32)
                 x[img] = [l, shape]
             except Exception as e:
-                x[img] = [None, None]
-                print('WARNING: %s: %s' % (img, e))
+                print('WARNING: Ignoring corrupted image and/or label:%s: %s' % (img, e))
 
         x['hash'] = get_hash(self.label_files + self.img_files)
         torch.save(x, path)  # save for next time
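The behavioral change is in the except branch: previously a corrupted sample was still written into the cache as x[img] = [None, None]; now it is logged and skipped, so the saved cache only contains usable images and labels. A compact sketch of the resulting scan-and-cache pattern, with scan_and_cache, check and hash_value as hypothetical names:

import torch

def scan_and_cache(samples, cache_path, hash_value, check):
    x = {}
    for img, label in samples:
        try:
            x[img] = check(img, label)  # e.g. returns [label_array, image_shape]
        except Exception as e:
            # Corrupt entries are reported and left out of the cache entirely
            print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))
    x['hash'] = hash_value
    torch.save(x, cache_path)  # reloaded with torch.load on the next run
    return x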