Merge pull request #2138 from cuicheng01/release/2.4
update multilabel_dataset.pypull/1975/head^2 v2.4.0
commit
3a28ee2900
|
@ -28,6 +28,7 @@ class MultiLabelDataset(CommonDataset):
|
|||
def _load_anno(self, label_ratio=False):
|
||||
assert os.path.exists(self._cls_path)
|
||||
assert os.path.exists(self._img_root)
|
||||
self.label_ratio = label_ratio
|
||||
self.images = []
|
||||
self.labels = []
|
||||
with open(self._cls_path) as fd:
|
||||
|
@ -41,7 +42,7 @@ class MultiLabelDataset(CommonDataset):
|
|||
|
||||
self.labels.append(labels)
|
||||
assert os.path.exists(self.images[-1])
|
||||
if label_ratio:
|
||||
if self.label_ratio is not False:
|
||||
return np.array(self.labels).mean(0).astype("float32")
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
@ -52,7 +53,7 @@ class MultiLabelDataset(CommonDataset):
|
|||
img = transform(img, self._transform_ops)
|
||||
img = img.transpose((2, 0, 1))
|
||||
label = np.array(self.labels[idx]).astype("float32")
|
||||
if self.label_ratio is not None:
|
||||
if self.label_ratio is not False:
|
||||
return (img, np.array([label, self.label_ratio]))
|
||||
else:
|
||||
return (img, label)
|
||||
|
|
Loading…
Reference in New Issue