From a346926996d62518ace72ea0f93fbd02478f4583 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 13 Oct 2021 15:48:54 -0700 Subject: [PATCH] Add class filtering to `LoadImagesAndLabels()` dataloader (#5172) * Add train class filter feature to datasets.py Allows for training on a subset of total classes if `include_class` list is defined on datasets.py L448: ```python include_class = [] # filter labels to include only these classes (optional) ``` * segments fix --- utils/datasets.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index a086e6bfb..4f9bd0f05 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -437,10 +437,6 @@ class LoadImagesAndLabels(Dataset): self.shapes = np.array(shapes, dtype=np.float64) self.img_files = list(cache.keys()) # update self.label_files = img2label_paths(cache.keys()) # update - if single_cls: - for x in self.labels: - x[:, 0] = 0 - n = len(shapes) # number of images bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index nb = bi[-1] + 1 # number of batches @@ -448,6 +444,20 @@ class LoadImagesAndLabels(Dataset): self.n = n self.indices = range(n) + # Update labels + include_class = [] # filter labels to include only these classes (optional) + include_class_array = np.array(include_class).reshape(1, -1) + for i, (label, segment) in enumerate(zip(self.labels, self.segments)): + if include_class: + j = (label[:, 0:1] == include_class_array).any(1) + self.labels[i] = label[j] + if segment: + self.segments[i] = segment[j] + if single_cls: # single-class training, merge all classes into 0 + self.labels[i][:, 0] = 0 + if segment: + self.segments[i][:, 0] = 0 + # Rectangular Training if self.rect: # Sort by aspect ratio