fix multilabel_dataset bug

pull/2115/head
cuicheng01 2022-06-29 06:51:40 +00:00
parent f3f0605c7e
commit a614dbe313
2 changed files with 5 additions and 4 deletions

View File

@ -122,8 +122,8 @@ Infer:
Metric:
Train:
- HammingDistance:
- AccuracyScore:
- HammingDistance:
Eval:
- HammingDistance:
- AccuracyScore:
- HammingDistance:

View File

@ -28,6 +28,7 @@ class MultiLabelDataset(CommonDataset):
def _load_anno(self, label_ratio=False):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
self.label_ratio = label_ratio
self.images = []
self.labels = []
with open(self._cls_path) as fd:
@ -41,7 +42,7 @@ class MultiLabelDataset(CommonDataset):
self.labels.append(labels)
assert os.path.exists(self.images[-1])
if label_ratio:
if self.label_ratio:
return np.array(self.labels).mean(0).astype("float32")
def __getitem__(self, idx):
@ -52,7 +53,7 @@ class MultiLabelDataset(CommonDataset):
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32")
if self.label_ratio is not None:
if self.label_ratio:
return (img, np.array([label, self.label_ratio]))
else:
return (img, label)