Attribute projects update

Summary: 1.fix zero divided in attribute metric computation;2.update market/duke attribute dataset loading.
Reviewed by: l1aoxingyu
pull/414/head
Xingyu Liao 2021-02-18 10:32:48 +08:00 committed by GitHub
commit cf46e5f071
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 196 additions and 6 deletions

View File

@ -0,0 +1,11 @@
_BASE_: Base-attribute.yml
DATASETS:
NAMES: ("DukeMTMCAttr",)
TESTS: ("DukeMTMCAttr",)
MODEL:
HEADS:
NUM_CLASSES: 23
OUTPUT_DIR: projects/FastAttr/logs/dukemtmc/strong_baseline

View File

@ -0,0 +1,11 @@
_BASE_: Base-attribute.yml
DATASETS:
NAMES: ("Market1501Attr",)
TESTS: ("Market1501Attr",)
MODEL:
HEADS:
NUM_CLASSES: 27
OUTPUT_DIR: projects/FastAttr/logs/market1501/strong_baseline

View File

@ -36,6 +36,8 @@ class AttrEvaluator(DatasetEvaluator):
@staticmethod
def get_attr_metrics(gt_labels, pred_logits, thres):
eps = 1e-20
pred_labels = copy.deepcopy(pred_logits)
pred_labels[pred_logits < thres] = 0
pred_labels[pred_logits >= thres] = 1
@ -53,13 +55,13 @@ class AttrEvaluator(DatasetEvaluator):
gt_labels = gt_labels.astype(bool)
intersect = (pred_labels & gt_labels).astype(float)
union = (pred_labels | gt_labels).astype(float)
ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean()
ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean()
ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean()
ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec)
ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean()
ins_prec = (intersect.sum(axis=1) / (pred_labels.astype(float).sum(axis=1) + eps)).mean()
ins_rec = (intersect.sum(axis=1) / (gt_labels.astype(float).sum(axis=1) + eps)).mean()
ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec + eps)
term1 = correct_pos / real_pos
term2 = correct_neg / real_neg
term1 = correct_pos / (real_pos + eps)
term2 = correct_neg / (real_neg + eps)
label_mA_verbose = (term1 + term2) * 0.5
label_mA = label_mA_verbose.mean()

View File

@ -6,3 +6,5 @@
# Attributed datasets
from .pa100k import PA100K
from .market1501attr import Market1501Attr
from .dukemtmcattr import DukeMTMCAttr

View File

@ -0,0 +1,74 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import glob
import os.path as osp
import re
import mat4py
import numpy as np
from fastreid.data.datasets import DATASET_REGISTRY
from .bases import Dataset
@DATASET_REGISTRY.register()
class DukeMTMCAttr(Dataset):
"""DukeMTMCAttr.
Reference:
Lin, Yutian, et al. "Improving person re-identification by attribute and identity learning."
Pattern Recognition 95 (2019): 151-161.
URL: `<https://github.com/vana77/DukeMTMC-attribute>`_
The folder structure should be:
DukeMTMC-reID/
bounding_box_train/ # images
bounding_box_test/ # images
duke_attribute.mat
"""
dataset_dir = 'DukeMTMC-reID'
dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
dataset_name = "dukemtmc"
def __init__(self, root='datasets', **kwargs):
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.test_dir = osp.join(self.dataset_dir, 'bounding_box_test')
required_files = [
self.dataset_dir,
self.train_dir,
self.test_dir,
]
self.check_before_run(required_files)
duke_attr = mat4py.loadmat(osp.join(self.dataset_dir, 'duke_attribute.mat'))['duke_attribute']
sorted_attrs = sorted(duke_attr['train'].keys())
sorted_attrs.remove('image_index')
attr_dict = {i: str(attr) for i, attr in enumerate(sorted_attrs)}
train = self.process_dir(self.train_dir, duke_attr['train'], sorted_attrs)
test = val = self.process_dir(self.test_dir, duke_attr['test'], sorted_attrs)
super(DukeMTMCAttr, self).__init__(train, val, test, attr_dict=attr_dict, **kwargs)
def process_dir(self, dir_path, annotation, sorted_attrs):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d)')
data = []
for img_path in img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
assert 1 <= camid <= 8
img_index = annotation['image_index'].index(str(pid).zfill(4))
attrs = np.array([int(annotation[i][img_index]) - 1 for i in sorted_attrs], dtype=np.float32)
data.append((img_path, attrs))
return data

View File

@ -0,0 +1,90 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import glob
import os.path as osp
import re
import warnings
import mat4py
import numpy as np
from fastreid.data.datasets import DATASET_REGISTRY
from .bases import Dataset
@DATASET_REGISTRY.register()
class Market1501Attr(Dataset):
"""Market1501Attr.
Reference:
Lin, Yutian, et al. "Improving person re-identification by attribute and identity learning."
Pattern Recognition 95 (2019): 151-161.
URL: `<https://github.com/vana77/Market-1501_Attribute>`_
The folder structure should be:
Market-1501-v15.09.15/
bounding_box_train/ # images
bounding_box_test/ # images
market_attribute.mat
"""
_junk_pids = [0, -1]
dataset_dir = ''
dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
dataset_name = "market1501"
def __init__(self, root='datasets', market1501_500k=False, **kwargs):
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
# allow alternative directory structure
self.data_dir = self.dataset_dir
data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
if osp.isdir(data_dir):
self.data_dir = data_dir
else:
warnings.warn('The current data structure is deprecated. Please '
'put data folders such as "bounding_box_train" under '
'"Market-1501-v15.09.15".')
self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
self.test_dir = osp.join(self.data_dir, 'bounding_box_test')
required_files = [
self.data_dir,
self.train_dir,
self.test_dir,
]
self.check_before_run(required_files)
market_attr = mat4py.loadmat(osp.join(self.data_dir, 'market_attribute.mat'))['market_attribute']
sorted_attrs = sorted(market_attr['train'].keys())
sorted_attrs.remove('image_index')
attr_dict = {i: str(attr) for i, attr in enumerate(sorted_attrs)}
train = self.process_dir(self.train_dir, market_attr['train'], sorted_attrs)
test = val = self.process_dir(self.test_dir, market_attr['test'], sorted_attrs)
super(Market1501Attr, self).__init__(train, val, test, attr_dict=attr_dict, **kwargs)
def process_dir(self, dir_path, annotation, sorted_attrs):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d)')
data = []
for img_path in img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
if pid == -1 or pid == 0:
continue # junk images are just ignored
assert 0 <= pid <= 1501 # pid == 0 means background
assert 1 <= camid <= 6
img_index = annotation['image_index'].index(str(pid).zfill(4))
attrs = np.array([int(annotation[i][img_index])-1 for i in sorted_attrs], dtype=np.float32)
data.append((img_path, attrs))
return data