dataloaders: fix class filtering for segmentation ()

* dataloaders: fix class filtering for segmentation

self.segments[i] and segment[j] are lists so they cannot be indexed with booleans
self.segments is a tuple so it has to be converted into a list first

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/11161/head^2
Eljas Hyyrynen 2023-03-23 21:29:11 +02:00 committed by GitHub
parent d223460f3a
commit 6dd17516c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -531,13 +531,14 @@ class LoadImagesAndLabels(Dataset):
# Update labels
include_class = [] # filter labels to include only these classes (optional)
self.segments = list(self.segments)
include_class_array = np.array(include_class).reshape(1, -1)
for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
if include_class:
j = (label[:, 0:1] == include_class_array).any(1)
self.labels[i] = label[j]
if segment:
self.segments[i] = segment[j]
self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
if single_cls: # single-class training, merge all classes into 0
self.labels[i][:, 0] = 0