diff --git a/ppcls/arch/backbone/model_zoo/strongbaseline_attr.py b/ppcls/arch/backbone/model_zoo/strongbaseline_attr.py
index 6bb445d1f..7e2725545 100644
--- a/ppcls/arch/backbone/model_zoo/strongbaseline_attr.py
+++ b/ppcls/arch/backbone/model_zoo/strongbaseline_attr.py
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/ppcls/configs/Attr/StrongBaselineAttr.yaml b/ppcls/configs/Attr/StrongBaselineAttr.yaml
index df6916bde..6718d8cbb 100644
--- a/ppcls/configs/Attr/StrongBaselineAttr.yaml
+++ b/ppcls/configs/Attr/StrongBaselineAttr.yaml
@@ -14,7 +14,6 @@ Global:
   image_shape: [3, 256, 192]
   save_inference_dir: "./inference"
   use_multilabel: True
-  metric_attr: True
 
 # model architecture
 Arch:
@@ -26,11 +25,15 @@ Arch:
 # loss function config for traing/eval process
 Loss:
   Train:
-    - BCELoss:
+    - MultiLabelLoss:
         weight: 1.0
+        weight_ratio: True
+        size_sum: True
   Eval:
-    - BCELoss:
+    - MultiLabelLoss:
         weight: 1.0
+        weight_ratio: True
+        size_sum: True
 
 Optimizer:
   name: Adam
@@ -47,10 +50,10 @@ Optimizer:
 DataLoader:
   Train:
     dataset:
-      name: AttrDataset
+      name: MultiLabelDataset
       image_root: "dataset/xingrenfenxi/data/"
-      cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl"
-      split: 'trainval'
+      cls_label_path: "dataset/xingrenfenxi/trainval.txt"
+      label_ratio: True
       transform_ops:
         - DecodeImage:
             to_rgb: True
@@ -80,10 +83,10 @@ DataLoader:
       use_shared_memory: True
   Eval:
     dataset:
-      name: AttrDataset
+      name: MultiLabelDataset
       image_root: "dataset/xingrenfenxi/data/"
-      cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl"
-      split: 'test'
+      cls_label_path: "dataset/xingrenfenxi/test.txt"
+      label_ratio: True
       transform_ops:
         - DecodeImage:
             to_rgb: True
diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 00bc50f0d..9fc4d760b 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -30,7 +30,6 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
 from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
-from ppcls.data.dataloader.attr_dataset import AttrDataset
 
 
 # sampler
diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py
index 7581daa0a..2b1d92b76 100644
--- a/ppcls/data/dataloader/__init__.py
+++ b/ppcls/data/dataloader/__init__.py
@@ -10,4 +10,3 @@ from ppcls.data.dataloader.mix_sampler import MixSampler
 from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
 from ppcls.data.dataloader.pk_sampler import PKSampler
 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
-from ppcls.data.dataloader.attr_dataset import AttrDataset
diff --git a/ppcls/data/dataloader/attr_dataset.py b/ppcls/data/dataloader/attr_dataset.py
deleted file mode 100644
index f4aaef2db..000000000
--- a/ppcls/data/dataloader/attr_dataset.py
+++ /dev/null
@@ -1,82 +0,0 @@
-#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import numpy as np
-import os
-import pickle
-
-from .common_dataset import CommonDataset
-from ppcls.data.preprocess import transform
-
-
-class AttrDataset(CommonDataset):
-    def _load_anno(self, seed=None, split='trainval'):
-        assert os.path.exists(self._cls_path)
-        assert os.path.exists(self._img_root)
-        anno_path = self._cls_path
-        image_dir = self._img_root
-        self.images = []
-        self.labels = []
-
-        dataset_info = pickle.load(open(anno_path, 'rb+'))
-        img_id = dataset_info.image_name
-
-        attr_label = dataset_info.label
-        attr_label[attr_label == 2] = 0
-        attr_id = dataset_info.attr_name
-        if 'label_idx' in dataset_info.keys():
-            eval_attr_idx = dataset_info.label_idx.eval
-            attr_label = attr_label[:, eval_attr_idx]
-            attr_id = [attr_id[i] for i in eval_attr_idx]
-
-        attr_num = len(attr_id)
-
-        # mapping category name to class id
-        # first_class:0, second_class:1, ...
-        cname2cid = {attr_id[i]: i for i in range(attr_num)}
-
-        assert split in dataset_info.partition.keys(
-        ), f'split {split} is not exist'
-
-        img_idx = dataset_info.partition[split]
-
-        if isinstance(img_idx, list):
-            img_idx = img_idx[0]  # default partition 0
-
-        img_num = img_idx.shape[0]
-        img_id = [img_id[i] for i in img_idx]
-        label = attr_label[img_idx]  # [:, [0, 12]]
-        self.label_ratio = label.mean(0)
-        print("label_ratio:", self.label_ratio)
-        for i, (img_i, label_i) in enumerate(zip(img_id, label)):
-            imgname = os.path.join(image_dir, img_i)
-            self.images.append(imgname)
-            self.labels.append(np.int64(label_i))
-
-    def __getitem__(self, idx):
-        try:
-            with open(self.images[idx], 'rb') as f:
-                img = f.read()
-            if self._transform_ops:
-                img = transform(img, self._transform_ops)
-            img = img.transpose((2, 0, 1))
-            return (img, [self.labels[idx], self.label_ratio])
-
-        except Exception as ex:
-            logger.error("Exception occured when parse line: {} with msg: {}".
-                         format(self.images[idx], ex))
-            rnd_idx = np.random.randint(self.__len__())
-            return self.__getitem__(rnd_idx)
diff --git a/ppcls/data/dataloader/common_dataset.py b/ppcls/data/dataloader/common_dataset.py
index fb251a7fc..88bab0f1d 100644
--- a/ppcls/data/dataloader/common_dataset.py
+++ b/ppcls/data/dataloader/common_dataset.py
@@ -48,7 +48,7 @@ class CommonDataset(Dataset):
                  image_root,
                  cls_label_path,
                  transform_ops=None,
-                 split='trainval'):
+                 label_ratio=False):
         self._img_root = image_root
         self._cls_path = cls_label_path
         if transform_ops:
@@ -56,7 +56,10 @@ class CommonDataset(Dataset):
 
         self.images = []
         self.labels = []
-        self._load_anno(split=split)
+        if label_ratio:
+            self.label_ratio = self._load_anno(label_ratio=label_ratio)
+        else:
+            self._load_anno()
 
     def _load_anno(self):
         pass
diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py
index 2c1ed7703..16bd5481f 100644
--- a/ppcls/data/dataloader/multilabel_dataset.py
+++ b/ppcls/data/dataloader/multilabel_dataset.py
@@ -25,7 +25,7 @@ from .common_dataset import CommonDataset
 
 
 class MultiLabelDataset(CommonDataset):
-    def _load_anno(self):
+    def _load_anno(self, label_ratio=False):
         assert os.path.exists(self._cls_path)
         assert os.path.exists(self._img_root)
         self.images = []
@@ -41,6 +41,8 @@ class MultiLabelDataset(CommonDataset):
 
                 self.labels.append(labels)
                 assert os.path.exists(self.images[-1])
+        if label_ratio:
+            return np.array(self.labels).mean(0)
 
     def __getitem__(self, idx):
         try:
@@ -50,7 +52,10 @@ class MultiLabelDataset(CommonDataset):
                 img = transform(img, self._transform_ops)
             img = img.transpose((2, 0, 1))
             label = np.array(self.labels[idx]).astype("float32")
-            return (img, label)
+            if self.label_ratio is not None:
+                return (img, [label, self.label_ratio])
+            else:
+                return (img, label)
 
         except Exception as ex:
             logger.error("Exception occured when parse line: {} with msg: {}".
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index d543b0b3f..ea78fe5b5 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -32,8 +32,8 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
-    if engine.eval_metric_func is not None and engine.config["Global"][
-            "metric_attr"]:
+    if engine.eval_metric_func is not None and engine.config["Arch"][
+            "name"] == "StrongBaselineAttr":
         output_info["attr"] = AttrMeter(threshold=0.5)
 
     metric_key = None
@@ -128,7 +128,7 @@ def classification_eval(engine, epoch_id=0):
 
         #  calc metric
         if engine.eval_metric_func is not None:
-            if engine.config["Global"]["metric_attr"]:
+            if engine.config["Arch"]["name"] == "StrongBaselineAttr":
                 metric_dict = engine.eval_metric_func(preds, labels)
                 metric_key = "attr"
                 output_info["attr"].update(metric_dict)
@@ -153,7 +153,7 @@ def classification_eval(engine, epoch_id=0):
             ips_msg = "ips: {:.5f} images/sec".format(
                 batch_size / time_info["batch_cost"].avg)
 
-            if engine.config["Global"]["metric_attr"]:
+            if engine.config["Arch"]["name"] == "StrongBaselineAttr":
                 metric_msg = ""
             else:
                 metric_msg = ", ".join([
@@ -168,7 +168,7 @@ def classification_eval(engine, epoch_id=0):
     if engine.use_dali:
         engine.eval_dataloader.reset()
 
-    if engine.config["Global"]["metric_attr"]:
+    if engine.config["Arch"]["name"] == "StrongBaselineAttr":
         metric_msg = ", ".join([
             "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
             format(*output_info["attr"].res())
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index 65e1b0b76..c1f2f95df 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -26,7 +26,6 @@ from .distillationloss import DistillationKLDivLoss
 from .distillationloss import DistillationDKDLoss
 from .multilabelloss import MultiLabelLoss
 from .afdloss import AFDLoss
-from .bceloss import BCELoss
 
 from .deephashloss import DSHSDLoss
 from .deephashloss import LCDSHLoss
diff --git a/ppcls/loss/bceloss.py b/ppcls/loss/bceloss.py
deleted file mode 100644
index 58418058d..000000000
--- a/ppcls/loss/bceloss.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-import paddle
-import paddle.nn as nn
-import paddle.nn.functional as F
-
-
-def ratio2weight(targets, ratio):
-    #     print(targets)
-    pos_weights = targets * (1. - ratio)
-    neg_weights = (1. - targets) * ratio
-    weights = paddle.exp(neg_weights + pos_weights)
-
-    # for RAP dataloader, targets element may be 2, with or without smooth, some element must great than 1
-    weights = weights - weights * (targets > 1)
-
-    return weights
-
-
-class BCELoss(nn.Layer):
-    """BCE Loss.
-    Args:
-        
-    """
-
-    def __init__(self,
-                 sample_weight=True,
-                 size_sum=True,
-                 smoothing=None,
-                 weight=1.0):
-        super(BCELoss, self).__init__()
-
-        self.sample_weight = sample_weight
-        self.size_sum = size_sum
-        self.hyper = 0.8
-        self.smoothing = smoothing
-
-    def forward(self, logits, labels):
-        targets, ratio = labels
-
-        if self.smoothing is not None:
-            targets = (1 - self.smoothing) * targets + self.smoothing * (
-                1 - targets)
-
-        targets = paddle.cast(targets, 'float32')
-
-        loss_m = F.binary_cross_entropy_with_logits(
-            logits, targets, reduction='none')
-
-        targets_mask = paddle.cast(targets > 0.5, 'float32')
-        if self.sample_weight:
-            weight = ratio2weight(targets_mask, ratio[0])
-            weight = weight * (targets > -1)
-            loss_m = loss_m * weight
-
-        loss = loss_m.sum(1).mean() if self.size_sum else loss_m.sum()
-
-        return {"BCELoss": loss}
diff --git a/ppcls/loss/multilabelloss.py b/ppcls/loss/multilabelloss.py
index 550db40f2..52c31c7da 100644
--- a/ppcls/loss/multilabelloss.py
+++ b/ppcls/loss/multilabelloss.py
@@ -19,12 +19,13 @@ class MultiLabelLoss(nn.Layer):
     Multi-label loss
     """
 
-    def __init__(self, epsilon=None, weight_ratio=None):
+    def __init__(self, epsilon=None, size_sum=False, weight_ratio=False):
         super().__init__()
         if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
             epsilon = None
         self.epsilon = epsilon
         self.weight_ratio = weight_ratio
+        self.size_sum = size_sum
 
     def _labelsmoothing(self, target, class_num):
         if target.ndim == 1 or target.shape[-1] != class_num:
@@ -36,18 +37,21 @@ class MultiLabelLoss(nn.Layer):
         return soft_target
 
     def _binary_crossentropy(self, input, target, class_num):
+        if self.weight_ratio:
+            target, label_ratio = target
         if self.epsilon is not None:
             target = self._labelsmoothing(target, class_num)
-        cost = F.binary_cross_entropy_with_logits(logit=input, label=target)
+        cost = F.binary_cross_entropy_with_logits(
+            logit=input, label=target, reduction='none')
 
-        if self.weight_ratio is not None:
+        if self.weight_ratio:
             targets_mask = paddle.cast(target > 0.5, 'float32')
-            weight = ratio2weight(targets_mask,
-                                  paddle.to_tensor(self.weight_ratio))
+            weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio))
             weight = weight * (target > -1)
             cost = cost * weight
-            import pdb
-            pdb.set_trace()
+
+        if self.size_sum:
+            cost = cost.sum(1).mean() if self.size_sum else cost.mean()
 
         return cost
 
diff --git a/requirements.txt b/requirements.txt
index 79f548c22..5e927756a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@ scipy
 scikit-learn==0.23.2
 gast==0.3.3
 faiss-cpu==1.7.1.post2
+easydict