From dcd90c52bf877ad803e0544aa8cd9ee85e8c9cd4 Mon Sep 17 00:00:00 2001
From: cuicheng01 <cuicheng_smile@163.com>
Date: Thu, 7 Jul 2022 06:13:06 +0000
Subject: [PATCH] update multilabel_dataset.py

---
 ppcls/data/dataloader/multilabel_dataset.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py
index 25dfc12b5..c67a5ae78 100644
--- a/ppcls/data/dataloader/multilabel_dataset.py
+++ b/ppcls/data/dataloader/multilabel_dataset.py
@@ -28,6 +28,7 @@ class MultiLabelDataset(CommonDataset):
     def _load_anno(self, label_ratio=False):
         assert os.path.exists(self._cls_path)
         assert os.path.exists(self._img_root)
+        self.label_ratio = label_ratio
         self.images = []
         self.labels = []
         with open(self._cls_path) as fd:
@@ -41,7 +42,7 @@ class MultiLabelDataset(CommonDataset):
 
                 self.labels.append(labels)
                 assert os.path.exists(self.images[-1])
-        if label_ratio:
+        if self.label_ratio is not False:
             return np.array(self.labels).mean(0).astype("float32")
 
     def __getitem__(self, idx):
@@ -52,7 +53,7 @@ class MultiLabelDataset(CommonDataset):
                 img = transform(img, self._transform_ops)
             img = img.transpose((2, 0, 1))
             label = np.array(self.labels[idx]).astype("float32")
-            if self.label_ratio is not None:
+            if self.label_ratio is not False:
                 return (img, np.array([label, self.label_ratio]))
             else:
                 return (img, label)