update attribute project

pull/299/head
liaoxingyu 2020-09-23 19:45:13 +08:00
parent 5b88736e1d
commit 5dfe537515
18 changed files with 789 additions and 15 deletions

View File

@@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identificati
## What's New
- [Sep 2020] Added person attribute recognition code based on fastreid. See `projects/attribute_recognition`.
- [Sep 2020] Automatic Mixed Precision training is supported via the PyTorch 1.6 built-in `torch.cuda.amp`. Set `cfg.SOLVER.AMP_ENABLED=True` to switch it on (see the sketch after this list).
- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/DistillReID) is supported, thanks to [guan'an wang](https://github.com/wangguanan) for the contribution.
- [Aug 2020] ONNX/TensorRT converter is supported.
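
Enabling AMP needs no code changes; the config key can be overridden on the command line. A minimal sketch, assuming the stock bag-of-tricks config (the config path is illustrative):

```bash
python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml SOLVER.AMP_ENABLED True
```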

View File

@@ -8,4 +8,4 @@ from .build import REID_HEADS_REGISTRY, build_heads
 # import all the meta_arch, so they will be registered
 from .embedding_head import EmbeddingHead
-from .cls_head import CLSHead
+from .attr_head import AttrHead

View File

@@ -4,7 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """
-import torch.nn.functional as F
+import torch
 from torch import nn
 from fastreid.layers import *
@@ -13,7 +13,7 @@ from .build import REID_HEADS_REGISTRY
 @REID_HEADS_REGISTRY.register()
-class CLSHead(nn.Module):
+class AttrHead(nn.Module):
     def __init__(self, cfg):
         super().__init__()
         # fmt: off
@@ -46,7 +46,7 @@ class CLSHead(nn.Module):
         # bottleneck = []
         # if with_bnneck:
         #     bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
-        bottleneck = [nn.BatchNorm1d(feat_dim)]
+        bottleneck = [nn.BatchNorm1d(num_classes)]
         self.bottleneck = nn.Sequential(*bottleneck)
@@ -60,16 +60,18 @@ class CLSHead(nn.Module):
         global_feat = self.pool_layer(features)
         global_feat = global_feat[..., 0, 0]
-        if self.classifier.__class__.__name__ == 'Linear':
-            cls_outputs = self.classifier(global_feat)
-            pred_class_logits = F.linear(global_feat, self.classifier.weight)
-        else:
-            cls_outputs = self.classifier(global_feat, targets)
-            pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat),
-                                                             F.normalize(self.classifier.weight))
+        classifier_name = self.classifier.__class__.__name__
+        # fmt: off
+        if classifier_name == 'Linear': cls_outputs = self.classifier(global_feat)
+        else:                           cls_outputs = self.classifier(global_feat, targets)
+        # fmt: on
         cls_outputs = self.bottleneck(cls_outputs)
-        return {
-            "cls_outputs": cls_outputs,
-            "pred_class_logits": pred_class_logits,
-        }
+        if self.training:
+            return {
+                "cls_outputs": cls_outputs,
+            }
+        else:
+            cls_outputs = torch.sigmoid(cls_outputs)
+            return cls_outputs
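
The net effect of this change: during training the head returns raw per-attribute logits for the BCE loss, while at inference it applies a sigmoid and returns probabilities. A minimal sketch (not part of the commit; shapes assume `FEAT_DIM=2048` and `NUM_CLASSES=26` from the project config):

```python
# head = AttrHead(cfg)  # cfg construction omitted
head.train()
out = head(features, targets)  # {"cls_outputs": tensor of shape (N, 26)}

head.eval()
probs = head(features)         # tensor of shape (N, 26), sigmoid scores in [0, 1]
```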

View File

@@ -0,0 +1,26 @@
# Person Attribute Recognition in FastReID
## Training and Evaluation
To train a model, run:
```bash
python3 projects/attribute_recognition/train_net.py --config-file <config.yaml> --num-gpus 1
```
For example, to train the attribute recognition network with a ResNet-50 backbone on the PA100k dataset, run:
```bash
python3 projects/attribute_recognition/train_net.py --config-file projects/attribute_recognition/configs/pa100.yml --num-gpus 4
```
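To evaluate a trained model, pass `--eval-only` together with a `MODEL.WEIGHTS` override, as handled in `train_net.py`; a sketch (the checkpoint path is hypothetical):
```bash
python3 projects/attribute_recognition/train_net.py --config-file projects/attribute_recognition/configs/pa100.yml \
    --eval-only MODEL.WEIGHTS logs/pa100k/strong_baseline/model_final.pth
```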
## Results
### PA100k
| Method | mA | Accu | Prec | Recall | F1 |
|:--:|:--:|:--:|:--:|:--:|:--:|
| Strongbaseline | 77.76 | 77.59 | 88.38 | 84.35 | 86.32 |
More datasets and test results will be added; stay tuned!

View File

@@ -0,0 +1,12 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from .config import add_attr_config
from .datasets import *
from .attr_baseline import AttrBaseline
from .attr_evaluation import AttrEvaluator
from .data_build import build_attr_train_loader, build_attr_test_loader
from .attr_trainer import AttrTrainer

View File

@@ -0,0 +1,41 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from fastreid.modeling.meta_arch.baseline import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY

from .bce_loss import cross_entropy_sigmoid_loss


@META_ARCH_REGISTRY.register()
class AttrBaseline(Baseline):

    def losses(self, outs, sample_weight=None):
        r"""
        Compute loss from the model's outputs; the loss function's input
        arguments must match the outputs of the model's forward pass.
        """
        # fmt: off
        outputs = outs["outputs"]
        gt_labels = outs["targets"]
        # model predictions
        # pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs = outputs['cls_outputs']
        # fmt: on

        # Log prediction accuracy
        # log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "BinaryCrossEntropyLoss" in loss_names:
            loss_dict['loss_bce'] = cross_entropy_sigmoid_loss(
                cls_outputs,
                gt_labels,
                sample_weight,
            ) * self._cfg.MODEL.LOSSES.BCE.SCALE

        return loss_dict
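
For reference, a minimal sketch (not part of the commit) of the dict `losses()` consumes — it mirrors `AttrHead`'s training-time output as wrapped by the meta-architecture; tensor values are hypothetical:

```python
import torch

outs = {
    "outputs": {"cls_outputs": torch.randn(8, 26)},   # raw per-attribute logits
    "targets": torch.randint(0, 2, (8, 26)).float(),  # binary attribute labels
}
# loss_dict = model.losses(outs, sample_weight)  # -> {"loss_bce": tensor}
```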

View File

@@ -0,0 +1,96 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
from collections import OrderedDict

import torch

from fastreid.evaluation.evaluator import DatasetEvaluator
from fastreid.utils import comm

logger = logging.getLogger(__name__)


class AttrEvaluator(DatasetEvaluator):
    def __init__(self, cfg, attr_dict, thres=0.5, output_dir=None):
        self.cfg = cfg
        self.attr_dict = attr_dict
        self.thres = thres
        self._output_dir = output_dir

        self.pred_logits = []
        self.gt_labels = []

    def reset(self):
        self.pred_logits = []
        self.gt_labels = []

    def process(self, inputs, outputs):
        self.gt_labels.extend(inputs["targets"])
        self.pred_logits.extend(outputs.cpu())

    @staticmethod
    def get_attr_metrics(gt_labels, pred_logits, thres):
        pred_labels = copy.deepcopy(pred_logits)
        pred_labels[pred_logits < thres] = 0
        pred_labels[pred_logits >= thres] = 1

        # Compute label-based metric
        overlaps = pred_labels * gt_labels
        correct_pos = overlaps.sum(axis=0)
        real_pos = gt_labels.sum(axis=0)
        inv_overlaps = (1 - pred_labels) * (1 - gt_labels)
        correct_neg = inv_overlaps.sum(axis=0)
        real_neg = (1 - gt_labels).sum(axis=0)

        # Compute instance-based accuracy
        pred_labels = pred_labels.astype(bool)
        gt_labels = gt_labels.astype(bool)
        intersect = (pred_labels & gt_labels).astype(float)
        union = (pred_labels | gt_labels).astype(float)
        ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean()
        ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean()
        ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean()
        ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec)

        term1 = correct_pos / real_pos
        term2 = correct_neg / real_neg
        label_mA_verbose = (term1 + term2) * 0.5
        label_mA = label_mA_verbose.mean()

        results = OrderedDict()
        results["Accu"] = ins_acc
        results["Prec"] = ins_prec
        results["Recall"] = ins_rec
        results["F1"] = ins_f1
        results["mA"] = label_mA
        return results

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            pred_logits = comm.gather(self.pred_logits)
            pred_logits = sum(pred_logits, [])

            gt_labels = comm.gather(self.gt_labels)
            gt_labels = sum(gt_labels, [])

            if not comm.is_main_process():
                return {}
        else:
            pred_logits = self.pred_logits
            gt_labels = self.gt_labels

        pred_logits = torch.stack(pred_logits, dim=0).numpy()
        gt_labels = torch.stack(gt_labels, dim=0).numpy()

        # Pedestrian attribute metrics
        thres = self.cfg.TEST.THRES
        self._results = self.get_attr_metrics(gt_labels, pred_logits, thres)

        return copy.deepcopy(self._results)
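
A toy check of `get_attr_metrics` (not part of the commit; numbers are hypothetical sigmoid outputs):

```python
import numpy as np

gt = np.array([[1., 0., 1.],
               [0., 1., 0.]], dtype=np.float32)
probs = np.array([[0.9, 0.2, 0.6],
                  [0.4, 0.8, 0.3]], dtype=np.float32)

metrics = AttrEvaluator.get_attr_metrics(gt, probs, thres=0.5)
# Thresholding at 0.5 recovers the ground truth exactly here,
# so Accu, Prec, Recall, F1 and mA all come out to 1.0.
```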

View File

@@ -0,0 +1,89 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import time

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.cuda import amp

from fastreid.engine import DefaultTrainer

from .data_build import build_attr_train_loader, build_attr_test_loader
from .attr_evaluation import AttrEvaluator


class AttrTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Sample weights for imbalanced attribute classification
        bce_weight_enabled = self.cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED
        # fmt: off
        if bce_weight_enabled: self.sample_weights = self.data_loader.dataset.sample_weights.to("cuda")
        else:                  self.sample_weights = None
        # fmt: on

    @classmethod
    def build_train_loader(cls, cfg):
        return build_attr_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        return build_attr_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        data_loader = cls.build_test_loader(cfg, dataset_name)
        return data_loader, AttrEvaluator(cfg, data_loader.dataset.attr_dict, output_dir=output_folder)

    def run_step(self):
        r"""
        Implement the attribute model training logic.
        """
        assert self.model.training, "[AttrTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the heads, you can wrap the model.
        """
        with amp.autocast(enabled=self.amp_enabled):
            outs = self.model(data)

            # Compute loss
            if isinstance(self.model, DistributedDataParallel):
                loss_dict = self.model.module.losses(outs, self.sample_weights)
            else:
                loss_dict = self.model.losses(outs, self.sample_weights)

            losses = sum(loss_dict.values())

        with torch.cuda.stream(torch.cuda.Stream()):
            metrics_dict = loss_dict
            metrics_dict["data_time"] = data_time
            self._write_metrics(metrics_dict)
            self._detect_anomaly(losses, loss_dict)

        """
        If you need to accumulate gradients or something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()

        if self.amp_enabled:
            self.scaler.scale(losses).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses.backward()
            """
            If you need gradient clipping/scaling or other processing, you can
            wrap the optimizer with your custom `step()` method.
            """
            self.optimizer.step()

View File

@@ -0,0 +1,33 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F


def ratio2weight(targets, ratio):
    pos_weights = targets * (1 - ratio)
    neg_weights = (1 - targets) * ratio
    weights = torch.exp(neg_weights + pos_weights)

    weights[targets > 1] = 0.0
    return weights


def cross_entropy_sigmoid_loss(pred_class_logits, gt_classes, sample_weight=None):
    loss = F.binary_cross_entropy_with_logits(pred_class_logits, gt_classes, reduction='none')

    if sample_weight is not None:
        # use the device of the inputs instead of hard-coded "cuda"
        targets_mask = torch.where(gt_classes.detach() > 0.5,
                                   torch.ones(1, device=gt_classes.device),
                                   torch.zeros(1, device=gt_classes.device))  # dtype float32
        weight = ratio2weight(targets_mask, sample_weight)
        loss = loss * weight

    with torch.no_grad():
        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
    loss = loss.sum() / non_zero_cnt
    return loss
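
The weighting scheme up-weights rare positives: for an attribute with positive ratio `r`, a positive label is weighted `exp(1 - r)` and a negative label `exp(r)`. A minimal sketch (not part of the commit; tensors are hypothetical):

```python
import torch

logits = torch.tensor([[2.0, -1.0]])   # raw scores for two attributes
targets = torch.tensor([[1.0, 0.0]])
ratio = torch.tensor([0.1, 0.5])       # per-attribute positive frequency

# The positive label of rare attribute 0 gets weight exp(0.9) ~ 2.46,
# while a negative label of the same attribute would only get exp(0.1) ~ 1.11.
loss = cross_entropy_sigmoid_loss(logits, targets, sample_weight=ratio)
```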

View File

@@ -0,0 +1,47 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch.utils.data import Dataset

from fastreid.data.data_utils import read_image


class AttrDataset(Dataset):
    """Image Person Attribute Dataset"""

    def __init__(self, img_items, attr_dict, transform=None):
        self.img_items = img_items
        self.attr_dict = attr_dict
        self.transform = transform

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        img_path, labels = self.img_items[index]
        img = read_image(img_path)

        if self.transform is not None: img = self.transform(img)

        labels = torch.from_numpy(labels)

        return {
            "images": img,
            "targets": labels,
            "img_paths": img_path,
        }

    @property
    def num_classes(self):
        return len(self.attr_dict)

    @property
    def sample_weights(self):
        sample_weights = torch.zeros(self.num_classes, dtype=torch.float)
        for _, attr in self.img_items:
            sample_weights += torch.from_numpy(attr)
        sample_weights /= len(self)
        return sample_weights
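
`sample_weights` is simply the per-attribute fraction of positive labels over the whole set; a minimal sketch (not part of the commit; paths and attribute names are hypothetical):

```python
import numpy as np

items = [("a.jpg", np.array([1., 0., 1.], dtype=np.float32)),
         ("b.jpg", np.array([1., 1., 0.], dtype=np.float32))]
ds = AttrDataset(items, attr_dict={0: "hat", 1: "glasses", 2: "backpack"})

print(ds.sample_weights)  # tensor([1.0000, 0.5000, 0.5000])
```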

View File

@@ -0,0 +1,17 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from fastreid.config import CfgNode as CN


def add_attr_config(cfg):
    _C = cfg

    _C.MODEL.LOSSES.BCE = CN()
    _C.MODEL.LOSSES.BCE.WEIGHT_ENABLED = True
    _C.MODEL.LOSSES.BCE.SCALE = 1.

    _C.TEST.THRES = 0.5

View File

@@ -0,0 +1,82 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import os

import torch
from torch.utils.data import DataLoader

from fastreid.utils import comm
from .common_attr import AttrDataset
from fastreid.data import samplers
from fastreid.data.build import fast_batch_collator
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.transforms import build_transforms

_root = os.getenv("FASTREID_DATASETS", "datasets")


def build_attr_train_loader(cfg):
    cfg = cfg.clone()
    cfg.defrost()

    train_items = list()
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, "attr_dict in {} does not match with previous ones".format(d)
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, attr_dict, train_transforms)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
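
Note that `MAX_ITER` in the config is read as epochs and rescaled to iterations here; a worked example, assuming PA100K's 80k training images and the base config's `IMS_PER_BATCH: 64` / `MAX_ITER: 30`:

```python
iters_per_epoch = 80_000 // 64    # -> 1250
max_iter = 30 * iters_per_epoch   # 30 "epochs" -> 37,500 iterations
```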
def build_attr_test_loader(cfg, dataset_name):
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=0,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader

View File

@@ -0,0 +1,8 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# Attribute datasets
from .pa100k import PA100K

View File

@@ -0,0 +1,127 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
import os

from tabulate import tabulate
from termcolor import colored

logger = logging.getLogger("fastreid." + __name__)


class Dataset(object):
    def __init__(
            self,
            train,
            val,
            test,
            attr_dict,
            mode='train',
            verbose=True,
            **kwargs,
    ):
        self.train = train
        self.val = val
        self.test = test
        self._attr_dict = attr_dict
        self._num_attrs = len(self.attr_dict)

        if mode == 'train':
            self.data = self.train
        elif mode == 'val':
            self.data = self.val
        else:
            self.data = self.test

    @property
    def num_attrs(self):
        return self._num_attrs

    @property
    def attr_dict(self):
        return self._attr_dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raise NotImplementedError

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.
        Args:
            required_files (str or list): string file name(s).
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def combine_all(self):
        """Combines train, val and test in a dataset for training."""
        combined = copy.deepcopy(self.train)

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def show_train(self):
        num_train = len(self.train)
        num_val = len(self.val)
        num_total = num_train + num_val

        headers = ['subset', '# images']
        csv_results = [
            ['train', num_train],
            ['val', num_val],
            ['total', num_total],
        ]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

        logger.info("attributes:")
        for label, attr in self.attr_dict.items():
            logger.info('{:3d}: {}'.format(label, attr))
        logger.info("------------------------------")
        logger.info("# attributes: {}".format(len(self.attr_dict)))

    def show_test(self):
        num_test = len(self.test)

        headers = ['subset', '# images']
        csv_results = [
            ['test', num_test],
        ]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

View File

@@ -0,0 +1,65 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import os.path as osp

import numpy as np
from scipy.io import loadmat

from fastreid.data.datasets import DATASET_REGISTRY
from .bases import Dataset


@DATASET_REGISTRY.register()
class PA100K(Dataset):
    """Pedestrian attribute dataset.
    80k training images + 10k validation images + 10k test images.
    The folder structure should be:
    pa100k/
        data/  # images
        annotation.mat
    """
    dataset_dir = 'PA-100K'

    def __init__(self, root='', **kwargs):
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        self.data_dir = osp.join(self.dataset_dir, 'data')
        self.anno_mat_path = osp.join(
            self.dataset_dir, 'annotation.mat'
        )

        required_files = [self.data_dir, self.anno_mat_path]
        self.check_before_run(required_files)

        train, val, test, attr_dict = self.extract_data()
        super(PA100K, self).__init__(train, val, test, attr_dict=attr_dict, **kwargs)

    def extract_data(self):
        # anno_mat is a dictionary with keys: ['test_images_name', 'val_images_name',
        # 'train_images_name', 'val_label', 'attributes', 'test_label', 'train_label']
        anno_mat = loadmat(self.anno_mat_path)

        def _extract(key_name, key_label):
            names = anno_mat[key_name]
            labels = anno_mat[key_label]
            num_imgs = names.shape[0]
            data = []
            for i in range(num_imgs):
                name = names[i, 0][0]
                attrs = labels[i, :].astype(np.float32)
                img_path = osp.join(self.data_dir, name)
                data.append((img_path, attrs))
            return data

        train = _extract('train_images_name', 'train_label')
        val = _extract('val_images_name', 'val_label')
        test = _extract('test_images_name', 'test_label')
        attrs = anno_mat['attributes']
        attr_dict = {i: str(attr[0][0]) for i, attr in enumerate(attrs)}

        return train, val, test, attr_dict

View File

@@ -0,0 +1,63 @@
MODEL:
  META_ARCHITECTURE: "AttrBaseline"

  BACKBONE:
    NAME: "build_resnet_backbone"
    NORM: "BN"
    DEPTH: "50x"
    LAST_STRIDE: 2
    FEAT_DIM: 2048
    WITH_IBN: False
    PRETRAIN: True
    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50-19c8e357.pth"

  HEADS:
    NAME: "AttrHead"
    NORM: "BN"
    WITH_BNNECK: True
    POOL_LAYER: "fastavgpool"
    CLS_LAYER: "linear"
    NUM_CLASSES: 26

  LOSSES:
    NAME: ("BinaryCrossEntropyLoss",)

    BCE:
      WEIGHT_ENABLED: True
      SCALE: 1.

INPUT:
  SIZE_TRAIN: [256, 128]
  SIZE_TEST: [256, 128]
  REA:
    ENABLED: False
  DO_PAD: True

DATALOADER:
  NUM_WORKERS: 8

SOLVER:
  OPT: "SGD"
  MAX_ITER: 30
  BASE_LR: 0.01
  BIAS_LR_FACTOR: 2.
  HEADS_LR_FACTOR: 10.
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  IMS_PER_BATCH: 64

  SCHED: "WarmupCosineAnnealingLR"
  DELAY_ITERS: 5
  ETA_MIN_LR: 0.00001

  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 5

  CHECKPOINT_PERIOD: 10

TEST:
  EVAL_PERIOD: 10
  IMS_PER_BATCH: 256

CUDNN_BENCHMARK: True

View File

@@ -0,0 +1,7 @@
_BASE_: "Base-attribute.yml"

DATASETS:
  NAMES: ("PA100K",)
  TESTS: ("PA100K",)

OUTPUT_DIR: "projects/attribute_recognition/logs/pa100k/strong_baseline"

View File

@@ -0,0 +1,58 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import sys

sys.path.append('.')

from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer

from attribute_baseline import *


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    add_attr_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)

    if args.eval_only:
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False
        model = AttrTrainer.build_model(cfg)

        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

        res = AttrTrainer.test(cfg, model)
        return res

    trainer = AttrTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )