Add `min_items` filter option (#9997)
* Add `min_items` filter option

  @AyushExel @Laughing-q dataset filter

  Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Update dataloaders.py

  Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

pull/10000/head
parent
cf99788823
commit
c55e2cd73b
|
@ -444,6 +444,7 @@ class LoadImagesAndLabels(Dataset):
|
|||
single_cls=False,
|
||||
stride=32,
|
||||
pad=0.0,
|
||||
min_items=0,
|
||||
prefix=''):
|
||||
self.img_size = img_size
|
||||
self.augment = augment
|
||||
|
@ -475,7 +476,7 @@ class LoadImagesAndLabels(Dataset):
|
|||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert self.im_files, f'{prefix}No images found'
|
||||
except Exception as e:
|
||||
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
|
||||
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
|
||||
|
||||
# Check cache
|
||||
self.label_files = img2label_paths(self.im_files) # labels
|
||||
|
@ -505,7 +506,19 @@ class LoadImagesAndLabels(Dataset):
|
|||
self.shapes = np.array(shapes)
|
||||
self.im_files = list(cache.keys()) # update
|
||||
self.label_files = img2label_paths(cache.keys()) # update
|
||||
n = len(shapes) # number of images
|
||||
|
||||
# Filter images
|
||||
if min_items:
|
||||
include = np.array([len(x) > min_items for x in self.labels]).nonzero()[0].astype(int)
|
||||
LOGGER.info(f'{prefix}{nf - len(include)}/{nf} images filtered from dataset')
|
||||
self.im_files = [self.im_files[i] for i in include]
|
||||
self.label_files = [self.label_files[i] for i in include]
|
||||
self.labels = [self.labels[i] for i in include]
|
||||
self.segments = [self.segments[i] for i in include]
|
||||
self.shapes = self.shapes[include] # wh
|
||||
|
||||
# Create indices
|
||||
n = len(self.shapes) # number of images
|
||||
bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
self.batch = bi # batch index of image
|
||||
|
|
|
@ -93,12 +93,13 @@ class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels): # for training/testing
|
|||
single_cls=False,
|
||||
stride=32,
|
||||
pad=0,
|
||||
min_items=0,
|
||||
prefix="",
|
||||
downsample_ratio=1,
|
||||
overlap=False,
|
||||
):
|
||||
super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
|
||||
stride, pad, prefix)
|
||||
stride, pad, min_items, prefix)
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.overlap = overlap
|
||||
|
||||
|
|
Loading…
Reference in New Issue