diff --git a/projects/FastAttr/configs/dukemtmc.yml b/projects/FastAttr/configs/dukemtmc.yml new file mode 100644 index 0000000..fc68deb --- /dev/null +++ b/projects/FastAttr/configs/dukemtmc.yml @@ -0,0 +1,11 @@ +_BASE_: Base-attribute.yml + +DATASETS: + NAMES: ("DukeMTMCAttr",) + TESTS: ("DukeMTMCAttr",) + +MODEL: + HEADS: + NUM_CLASSES: 23 + +OUTPUT_DIR: projects/FastAttr/logs/dukemtmc/strong_baseline \ No newline at end of file diff --git a/projects/FastAttr/configs/market1501.yml b/projects/FastAttr/configs/market1501.yml new file mode 100644 index 0000000..8ec26d9 --- /dev/null +++ b/projects/FastAttr/configs/market1501.yml @@ -0,0 +1,11 @@ +_BASE_: Base-attribute.yml + +DATASETS: + NAMES: ("Market1501Attr",) + TESTS: ("Market1501Attr",) + +MODEL: + HEADS: + NUM_CLASSES: 27 + +OUTPUT_DIR: projects/FastAttr/logs/market1501/strong_baseline \ No newline at end of file diff --git a/projects/FastAttr/fastattr/attr_evaluation.py b/projects/FastAttr/fastattr/attr_evaluation.py index 3ab7247..eb06dde 100644 --- a/projects/FastAttr/fastattr/attr_evaluation.py +++ b/projects/FastAttr/fastattr/attr_evaluation.py @@ -36,6 +36,8 @@ class AttrEvaluator(DatasetEvaluator): @staticmethod def get_attr_metrics(gt_labels, pred_logits, thres): + eps = 1e-20 + pred_labels = copy.deepcopy(pred_logits) pred_labels[pred_logits < thres] = 0 pred_labels[pred_logits >= thres] = 1 @@ -53,13 +55,13 @@ class AttrEvaluator(DatasetEvaluator): gt_labels = gt_labels.astype(bool) intersect = (pred_labels & gt_labels).astype(float) union = (pred_labels | gt_labels).astype(float) - ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean() - ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean() - ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean() - ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec) + ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean() + ins_prec = (intersect.sum(axis=1) / 
@DATASET_REGISTRY.register()
class DukeMTMCAttr(Dataset):
    """DukeMTMC-reID images paired with the attribute labels of their identity.

    Reference:
        Lin, Yutian, et al. "Improving person re-identification by attribute and identity learning."
        Pattern Recognition 95 (2019): 151-161.

    URL: `<https://github.com/vana77/DukeMTMC-attribute>`_

    The folder structure should be:
        DukeMTMC-reID/
            bounding_box_train/ # images
            bounding_box_test/ # images
            duke_attribute.mat
    """
    dataset_dir = 'DukeMTMC-reID'
    dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
    dataset_name = "dukemtmc"

    def __init__(self, root='datasets', **kwargs):
        """Locate the dataset under ``root`` and build the train/val/test splits.

        Args:
            root: directory that contains the ``DukeMTMC-reID`` folder.
            **kwargs: forwarded unchanged to the base ``Dataset`` constructor.
        """
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.test_dir = osp.join(self.dataset_dir, 'bounding_box_test')
        attr_path = osp.join(self.dataset_dir, 'duke_attribute.mat')

        # The annotation file is checked together with the image folders so a
        # missing download is reported before mat4py tries to open it.
        # NOTE(review): assumes check_before_run only tests that each path exists.
        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.test_dir,
            attr_path,
        ]
        self.check_before_run(required_files)

        duke_attr = mat4py.loadmat(attr_path)['duke_attribute']
        # Attribute order is the alphabetical order of the annotation keys;
        # 'image_index' holds the identity ids, not an attribute.
        sorted_attrs = sorted(duke_attr['train'].keys())
        sorted_attrs.remove('image_index')
        attr_dict = {i: str(attr) for i, attr in enumerate(sorted_attrs)}

        train = self.process_dir(self.train_dir, duke_attr['train'], sorted_attrs)
        # The test split doubles as the validation split.
        test = val = self.process_dir(self.test_dir, duke_attr['test'], sorted_attrs)

        super().__init__(train, val, test, attr_dict=attr_dict, **kwargs)

    def process_dir(self, dir_path, annotation, sorted_attrs):
        """Pair every ``*.jpg`` in ``dir_path`` with its identity's attribute vector.

        Args:
            dir_path: folder holding the images of one split.
            annotation: mapping with an ``'image_index'`` list of zero-padded
                identity ids and, per attribute name, a list of values aligned
                with ``'image_index'``.
            sorted_attrs: attribute names, in the order used for the label vector.

        Returns:
            A list of ``(img_path, attrs)`` tuples where ``attrs`` is a float32
            array holding ``int(value) - 1`` for each name in ``sorted_attrs``.

        Raises:
            ValueError: if an image's identity id is absent from ``'image_index'``.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # Build the id -> row lookup once instead of scanning the list per image.
        # setdefault keeps the first occurrence, matching list.index().
        row_of_id = {}
        for row, image_id in enumerate(annotation['image_index']):
            row_of_id.setdefault(image_id, row)
        columns = [annotation[name] for name in sorted_attrs]

        data = []
        for img_path in img_paths:
            # Match on the file name only: a parent directory such as
            # "run12_c3/" would otherwise be picked up before the real id.
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            assert 1 <= camid <= 8

            image_id = str(pid).zfill(4)
            try:
                row = row_of_id[image_id]
            except KeyError:
                raise ValueError(
                    '{!r} is not in the annotation image_index (image: {})'.format(image_id, img_path)
                ) from None
            attrs = np.array([int(column[row]) - 1 for column in columns], dtype=np.float32)
            data.append((img_path, attrs))

        return data
@DATASET_REGISTRY.register()
class Market1501Attr(Dataset):
    """Market-1501 images paired with the attribute labels of their identity.

    Reference:
        Lin, Yutian, et al. "Improving person re-identification by attribute and identity learning."
        Pattern Recognition 95 (2019): 151-161.

    URL: `<https://github.com/vana77/Market-1501_Attribute>`_

    The folder structure should be:
        Market-1501-v15.09.15/
            bounding_box_train/ # images
            bounding_box_test/ # images
            market_attribute.mat
    """
    # Identity ids that carry no attribute annotation and are skipped.
    _junk_pids = [0, -1]
    dataset_dir = ''
    dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
    dataset_name = "market1501"

    def __init__(self, root='datasets', market1501_500k=False, **kwargs):
        """Locate the dataset under ``root`` and build the train/val/test splits.

        Args:
            root: directory that contains ``Market-1501-v15.09.15`` (or, in the
                deprecated layout, the image folders directly).
            market1501_500k: accepted for signature compatibility; this class
                does not use it.
            **kwargs: forwarded unchanged to the base ``Dataset`` constructor.
        """
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        # allow alternative directory structure
        self.data_dir = self.dataset_dir
        data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
        if osp.isdir(data_dir):
            self.data_dir = data_dir
        else:
            warnings.warn('The current data structure is deprecated. Please '
                          'put data folders such as "bounding_box_train" under '
                          '"Market-1501-v15.09.15".')

        self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
        self.test_dir = osp.join(self.data_dir, 'bounding_box_test')
        attr_path = osp.join(self.data_dir, 'market_attribute.mat')

        # The annotation file is checked together with the image folders so a
        # missing download is reported before mat4py tries to open it.
        # NOTE(review): assumes check_before_run only tests that each path exists.
        required_files = [
            self.data_dir,
            self.train_dir,
            self.test_dir,
            attr_path,
        ]
        self.check_before_run(required_files)

        market_attr = mat4py.loadmat(attr_path)['market_attribute']
        # Attribute order is the alphabetical order of the annotation keys;
        # 'image_index' holds the identity ids, not an attribute.
        sorted_attrs = sorted(market_attr['train'].keys())
        sorted_attrs.remove('image_index')
        attr_dict = {i: str(attr) for i, attr in enumerate(sorted_attrs)}

        train = self.process_dir(self.train_dir, market_attr['train'], sorted_attrs)
        # The test split doubles as the validation split.
        test = val = self.process_dir(self.test_dir, market_attr['test'], sorted_attrs)

        super().__init__(train, val, test, attr_dict=attr_dict, **kwargs)

    def process_dir(self, dir_path, annotation, sorted_attrs):
        """Pair every non-junk ``*.jpg`` in ``dir_path`` with its attribute vector.

        Args:
            dir_path: folder holding the images of one split.
            annotation: mapping with an ``'image_index'`` list of zero-padded
                identity ids and, per attribute name, a list of values aligned
                with ``'image_index'``.
            sorted_attrs: attribute names, in the order used for the label vector.

        Returns:
            A list of ``(img_path, attrs)`` tuples where ``attrs`` is a float32
            array holding ``int(value) - 1`` for each name in ``sorted_attrs``.
            Images whose id is in ``_junk_pids`` are left out.

        Raises:
            ValueError: if an image's identity id is absent from ``'image_index'``.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # Build the id -> row lookup once instead of scanning the list per image.
        # setdefault keeps the first occurrence, matching list.index().
        row_of_id = {}
        for row, image_id in enumerate(annotation['image_index']):
            row_of_id.setdefault(image_id, row)
        columns = [annotation[name] for name in sorted_attrs]

        data = []
        for img_path in img_paths:
            # Match on the file name only: a parent directory such as
            # "run12_c3/" would otherwise be picked up before the real id.
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            if pid in self._junk_pids:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # junk ids were skipped above, so this guards the id range
            assert 1 <= camid <= 6

            image_id = str(pid).zfill(4)
            try:
                row = row_of_id[image_id]
            except KeyError:
                raise ValueError(
                    '{!r} is not in the annotation image_index (image: {})'.format(image_id, img_path)
                ) from None
            # NOTE(review): the "- 1" presumes values coded from 1; an attribute
            # with more than two levels would give labels above 1 - confirm
            # against the dataset README.
            attrs = np.array([int(column[row]) - 1 for column in columns], dtype=np.float32)
            data.append((img_path, attrs))

        return data