Add `min_items` filter option (#9997)
* Add `min_items` filter option @AyushExel @Laughing-q dataset filter Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update dataloaders.py Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/10000/head
parent
cf99788823
commit
c55e2cd73b
|
@ -444,6 +444,7 @@ class LoadImagesAndLabels(Dataset):
|
||||||
single_cls=False,
|
single_cls=False,
|
||||||
stride=32,
|
stride=32,
|
||||||
pad=0.0,
|
pad=0.0,
|
||||||
|
min_items=0,
|
||||||
prefix=''):
|
prefix=''):
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
|
@ -475,7 +476,7 @@ class LoadImagesAndLabels(Dataset):
|
||||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||||
assert self.im_files, f'{prefix}No images found'
|
assert self.im_files, f'{prefix}No images found'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
|
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
|
||||||
|
|
||||||
# Check cache
|
# Check cache
|
||||||
self.label_files = img2label_paths(self.im_files) # labels
|
self.label_files = img2label_paths(self.im_files) # labels
|
||||||
|
@ -505,7 +506,19 @@ class LoadImagesAndLabels(Dataset):
|
||||||
self.shapes = np.array(shapes)
|
self.shapes = np.array(shapes)
|
||||||
self.im_files = list(cache.keys()) # update
|
self.im_files = list(cache.keys()) # update
|
||||||
self.label_files = img2label_paths(cache.keys()) # update
|
self.label_files = img2label_paths(cache.keys()) # update
|
||||||
n = len(shapes) # number of images
|
|
||||||
|
# Filter images
|
||||||
|
if min_items:
|
||||||
|
include = np.array([len(x) > min_items for x in self.labels]).nonzero()[0].astype(int)
|
||||||
|
LOGGER.info(f'{prefix}{nf - len(include)}/{nf} images filtered from dataset')
|
||||||
|
self.im_files = [self.im_files[i] for i in include]
|
||||||
|
self.label_files = [self.label_files[i] for i in include]
|
||||||
|
self.labels = [self.labels[i] for i in include]
|
||||||
|
self.segments = [self.segments[i] for i in include]
|
||||||
|
self.shapes = self.shapes[include] # wh
|
||||||
|
|
||||||
|
# Create indices
|
||||||
|
n = len(self.shapes) # number of images
|
||||||
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
|
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
|
||||||
nb = bi[-1] + 1 # number of batches
|
nb = bi[-1] + 1 # number of batches
|
||||||
self.batch = bi # batch index of image
|
self.batch = bi # batch index of image
|
||||||
|
|
|
@ -93,12 +93,13 @@ class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels): # for training/testing
|
||||||
single_cls=False,
|
single_cls=False,
|
||||||
stride=32,
|
stride=32,
|
||||||
pad=0,
|
pad=0,
|
||||||
|
min_items=0,
|
||||||
prefix="",
|
prefix="",
|
||||||
downsample_ratio=1,
|
downsample_ratio=1,
|
||||||
overlap=False,
|
overlap=False,
|
||||||
):
|
):
|
||||||
super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
|
super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
|
||||||
stride, pad, prefix)
|
stride, pad, min_items, prefix)
|
||||||
self.downsample_ratio = downsample_ratio
|
self.downsample_ratio = downsample_ratio
|
||||||
self.overlap = overlap
|
self.overlap = overlap
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue