Add `min_items` filter option (#9997)

* Add `min_items` filter option

@AyushExel @Laughing-q this adds a `min_items` dataset filter that drops images with too few labels.

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update dataloaders.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/10000/head
Glenn Jocher 2022-11-01 14:53:14 +01:00 committed by GitHub
parent cf99788823
commit c55e2cd73b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 3 deletions

View File

@ -444,6 +444,7 @@ class LoadImagesAndLabels(Dataset):
single_cls=False,
stride=32,
pad=0.0,
min_items=0,
prefix=''):
self.img_size = img_size
self.augment = augment
@ -475,7 +476,7 @@ class LoadImagesAndLabels(Dataset):
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert self.im_files, f'{prefix}No images found'
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
# Check cache
self.label_files = img2label_paths(self.im_files) # labels
@ -505,7 +506,19 @@ class LoadImagesAndLabels(Dataset):
self.shapes = np.array(shapes)
self.im_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update
n = len(shapes) # number of images
# Filter images
if min_items:
include = np.array([len(x) > min_items for x in self.labels]).nonzero()[0].astype(int)
LOGGER.info(f'{prefix}{nf - len(include)}/{nf} images filtered from dataset')
self.im_files = [self.im_files[i] for i in include]
self.label_files = [self.label_files[i] for i in include]
self.labels = [self.labels[i] for i in include]
self.segments = [self.segments[i] for i in include]
self.shapes = self.shapes[include] # wh
# Create indices
n = len(self.shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image

View File

@ -93,12 +93,13 @@ class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels): # for training/testing
single_cls=False,
stride=32,
pad=0,
min_items=0,
prefix="",
downsample_ratio=1,
overlap=False,
):
super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
stride, pad, prefix)
stride, pad, min_items, prefix)
self.downsample_ratio = downsample_ratio
self.overlap = overlap