mirror of https://github.com/JDAI-CV/fast-reid.git

commit 5dfe537515 "update attribute project" (parent 5b88736e1d)
@@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identification algorithms.

 ## What's New

+- [Sep 2020] Added person attribute recognition code based on fastreid. See `projects/attribute_recognition`.
 - [Sep 2020] Automatic Mixed Precision training is supported with the PyTorch 1.6 built-in `torch.cuda.amp`. Set `cfg.SOLVER.AMP_ENABLED=True` to switch it on.
 - [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/DistillReID) is supported, thanks to [guan'an wang](https://github.com/wangguanan)'s contribution.
 - [Aug 2020] An ONNX/TensorRT converter is supported.

@@ -8,4 +8,4 @@ from .build import REID_HEADS_REGISTRY, build_heads

 # import all the meta_arch, so they will be registered
 from .embedding_head import EmbeddingHead
-from .cls_head import CLSHead
+from .attr_head import AttrHead
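
The import matters because heads are constructed by name through the registry. A minimal sketch of the lookup, assuming it mirrors fastreid's standard `build.py` pattern:

```python
# Sketch of how registration is consumed (assumed to mirror fastreid's build.py):
from fastreid.modeling.heads import REID_HEADS_REGISTRY

def build_heads(cfg):
    # cfg.MODEL.HEADS.NAME is e.g. "AttrHead"; the import above registered it.
    head_name = cfg.MODEL.HEADS.NAME
    return REID_HEADS_REGISTRY.get(head_name)(cfg)
```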
@@ -4,7 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """

-import torch.nn.functional as F
+import torch
 from torch import nn

 from fastreid.layers import *
@@ -13,7 +13,7 @@ from .build import REID_HEADS_REGISTRY


 @REID_HEADS_REGISTRY.register()
-class CLSHead(nn.Module):
+class AttrHead(nn.Module):
     def __init__(self, cfg):
         super().__init__()
         # fmt: off
@@ -46,7 +46,7 @@ class CLSHead(nn.Module):
         # bottleneck = []
         # if with_bnneck:
         #     bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
-        bottleneck = [nn.BatchNorm1d(feat_dim)]
+        bottleneck = [nn.BatchNorm1d(num_classes)]

         self.bottleneck = nn.Sequential(*bottleneck)

@@ -60,16 +60,18 @@ class CLSHead(nn.Module):
         global_feat = self.pool_layer(features)
         global_feat = global_feat[..., 0, 0]

-        if self.classifier.__class__.__name__ == 'Linear':
-            cls_outputs = self.classifier(global_feat)
-            pred_class_logits = F.linear(global_feat, self.classifier.weight)
-        else:
-            cls_outputs = self.classifier(global_feat, targets)
-            pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat),
-                                                             F.normalize(self.classifier.weight))
+        classifier_name = self.classifier.__class__.__name__
+        # fmt: off
+        if classifier_name == 'Linear': cls_outputs = self.classifier(global_feat)
+        else:                           cls_outputs = self.classifier(global_feat, targets)
+        # fmt: on

         cls_outputs = self.bottleneck(cls_outputs)
-        return {
-            "cls_outputs": cls_outputs,
-            "pred_class_logits": pred_class_logits,
-        }
+
+        if self.training:
+            return {
+                "cls_outputs": cls_outputs,
+            }
+        else:
+            cls_outputs = torch.sigmoid(cls_outputs)
+            return cls_outputs
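
The new forward path returns raw logits for the BCE loss during training and sigmoid probabilities at inference. A usage sketch (the forward signature and shapes are assumptions, not confirmed by the diff):

```python
# Hypothetical usage; `features` is the backbone output, shape (N, C, H, W).
head = AttrHead(cfg)

head.train()
outs = head(features, targets)   # {"cls_outputs": logits, shape (N, num_classes)}

head.eval()
probs = head(features)           # per-attribute probabilities in [0, 1]
```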
@@ -0,0 +1,26 @@
# Person Attribute Recognition in FastReID

## Training and Evaluation

To train a model, run:

```bash
python3 projects/attribute_recognition/train_net.py --config-file <config.yaml> --num-gpus 1
```

For example, to train the attribute recognition network with a ResNet-50 backbone on the PA100k dataset, run:

```bash
python3 projects/attribute_recognition/train_net.py --config-file projects/attribute_recognition/configs/pa100k.yml --num-gpus 4
```
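
Evaluation of a trained checkpoint can presumably be run through the same script via `--eval-only` (the weights path below is illustrative, not part of the commit):

```bash
python3 projects/attribute_recognition/train_net.py --config-file projects/attribute_recognition/configs/pa100k.yml \
    --eval-only MODEL.WEIGHTS projects/attribute_recognition/logs/pa100k/strong_baseline/model_final.pth
```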

## Results

### PA100k

| Method | mA | Accu | Prec | Recall | F1 |
|:--:|:--:|:--:|:--:|:--:|:--:|
| Strongbaseline | 77.76 | 77.59 | 88.38 | 84.35 | 86.32 |

More datasets and test results are on the way, stay tuned!

@@ -0,0 +1,12 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

from .config import add_attr_config
from .datasets import *
from .attr_baseline import AttrBaseline
from .attr_evaluation import AttrEvaluator
from .data_build import build_attr_train_loader, build_attr_test_loader
from .attr_trainer import AttrTrainer

@@ -0,0 +1,41 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from fastreid.modeling.meta_arch.baseline import Baseline
from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
from .bce_loss import cross_entropy_sigmoid_loss


@META_ARCH_REGISTRY.register()
class AttrBaseline(Baseline):

    def losses(self, outs, sample_weight=None):
        r"""
        Compute loss from the model's outputs; the loss function's input
        arguments must match the outputs of the model's forward pass.
        """
        # fmt: off
        outputs     = outs["outputs"]
        gt_labels   = outs["targets"]
        # model predictions
        # pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs = outputs['cls_outputs']
        # fmt: on

        # Log prediction accuracy
        # log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "BinaryCrossEntropyLoss" in loss_names:
            loss_dict['loss_bce'] = cross_entropy_sigmoid_loss(
                cls_outputs,
                gt_labels,
                sample_weight,
            ) * self._cfg.MODEL.LOSSES.BCE.SCALE

        return loss_dict

@@ -0,0 +1,96 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
from collections import OrderedDict

import torch

from fastreid.evaluation.evaluator import DatasetEvaluator
from fastreid.utils import comm

logger = logging.getLogger(__name__)


class AttrEvaluator(DatasetEvaluator):
    def __init__(self, cfg, attr_dict, thres=0.5, output_dir=None):
        self.cfg = cfg
        self.attr_dict = attr_dict
        self.thres = thres
        self._output_dir = output_dir

        self.pred_logits = []
        self.gt_labels = []

    def reset(self):
        self.pred_logits = []
        self.gt_labels = []

    def process(self, inputs, outputs):
        self.gt_labels.extend(inputs["targets"])
        self.pred_logits.extend(outputs.cpu())

    @staticmethod
    def get_attr_metrics(gt_labels, pred_logits, thres):
        pred_labels = copy.deepcopy(pred_logits)
        pred_labels[pred_logits < thres] = 0
        pred_labels[pred_logits >= thres] = 1

        # Compute label-based metric
        overlaps = pred_labels * gt_labels
        correct_pos = overlaps.sum(axis=0)
        real_pos = gt_labels.sum(axis=0)
        inv_overlaps = (1 - pred_labels) * (1 - gt_labels)
        correct_neg = inv_overlaps.sum(axis=0)
        real_neg = (1 - gt_labels).sum(axis=0)

        # Compute instance-based accuracy
        pred_labels = pred_labels.astype(bool)
        gt_labels = gt_labels.astype(bool)
        intersect = (pred_labels & gt_labels).astype(float)
        union = (pred_labels | gt_labels).astype(float)
        ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean()
        ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean()
        ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean()
        ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec)

        term1 = correct_pos / real_pos
        term2 = correct_neg / real_neg
        label_mA_verbose = (term1 + term2) * 0.5
        label_mA = label_mA_verbose.mean()

        results = OrderedDict()
        results["Accu"] = ins_acc
        results["Prec"] = ins_prec
        results["Recall"] = ins_rec
        results["F1"] = ins_f1
        results["mA"] = label_mA
        return results

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            pred_logits = comm.gather(self.pred_logits)
            pred_logits = sum(pred_logits, [])

            gt_labels = comm.gather(self.gt_labels)
            gt_labels = sum(gt_labels, [])

            if not comm.is_main_process():
                return {}
        else:
            pred_logits = self.pred_logits
            gt_labels = self.gt_labels

        pred_logits = torch.stack(pred_logits, dim=0).numpy()
        gt_labels = torch.stack(gt_labels, dim=0).numpy()

        # Pedestrian attribute metrics
        thres = self.cfg.TEST.THRES
        self._results = self.get_attr_metrics(gt_labels, pred_logits, thres)

        return copy.deepcopy(self._results)
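
A quick hand-worked check of `get_attr_metrics` on a toy 2×2 case (not part of the commit):

```python
import numpy as np

gt   = np.array([[1., 0.],
                 [0., 1.]])
prob = np.array([[0.9, 0.8],
                 [0.2, 0.7]])          # thresholded at 0.5 -> [[1, 1], [0, 1]]

m = AttrEvaluator.get_attr_metrics(gt, prob, thres=0.5)
# Attribute 0: term1 = 1/1, term2 = 1/1; attribute 1: term1 = 1/1, term2 = 0/1
# => mA = (1.0 + 0.5) / 2 = 0.75; Accu = 0.75, Prec = 0.75, Recall = 1.0
```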
@@ -0,0 +1,89 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import time
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.cuda import amp
from fastreid.engine import DefaultTrainer
from .data_build import build_attr_train_loader, build_attr_test_loader
from .attr_evaluation import AttrEvaluator


class AttrTrainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Sample weights for imbalanced attribute classification
        bce_weight_enabled = self.cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED
        # fmt: off
        if bce_weight_enabled: self.sample_weights = self.data_loader.dataset.sample_weights.to("cuda")
        else:                  self.sample_weights = None
        # fmt: on

    @classmethod
    def build_train_loader(cls, cfg):
        return build_attr_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        return build_attr_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        data_loader = cls.build_test_loader(cfg, dataset_name)
        # Pass the dataset's attr_dict so the evaluator signature
        # (cfg, attr_dict, thres, output_dir) is satisfied.
        return data_loader, AttrEvaluator(cfg, data_loader.dataset.attr_dict, output_dir=output_folder)

    def run_step(self):
        r"""
        Implement the attribute model training logic.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the heads, you can wrap the model.
        """
        with amp.autocast(enabled=self.amp_enabled):
            outs = self.model(data)

            # Compute loss
            if isinstance(self.model, DistributedDataParallel):
                loss_dict = self.model.module.losses(outs, self.sample_weights)
            else:
                loss_dict = self.model.losses(outs, self.sample_weights)

            losses = sum(loss_dict.values())

        with torch.cuda.stream(torch.cuda.Stream()):
            metrics_dict = loss_dict
            metrics_dict["data_time"] = data_time
            self._write_metrics(metrics_dict)
            self._detect_anomaly(losses, loss_dict)

        """
        If you need to accumulate gradients or do something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()

        if self.amp_enabled:
            self.scaler.scale(losses).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses.backward()
            """
            If you need gradient clipping/scaling or other processing, you can
            wrap the optimizer with your custom `step()` method.
            """
            self.optimizer.step()
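
`run_step` relies on `self.scaler` and `self.amp_enabled` being set up by `DefaultTrainer` when `cfg.SOLVER.AMP_ENABLED` is on; the AMP branch above is the standard `torch.cuda.amp` recipe, sketched generically here (`model`, `inputs`, `optimizer`, and `compute_loss` are placeholders):

```python
# The generic torch.cuda.amp pattern that run_step specializes:
import torch

scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    loss = compute_loss(model, inputs)   # forward + loss under autocast (placeholder fn)
optimizer.zero_grad()
scaler.scale(loss).backward()            # scale to avoid fp16 gradient underflow
scaler.step(optimizer)                   # unscales gradients before stepping
scaler.update()
```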
@@ -0,0 +1,33 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F


def ratio2weight(targets, ratio):
    pos_weights = targets * (1 - ratio)
    neg_weights = (1 - targets) * ratio
    weights = torch.exp(neg_weights + pos_weights)

    weights[targets > 1] = 0.0
    return weights


def cross_entropy_sigmoid_loss(pred_class_logits, gt_classes, sample_weight=None):
    loss = F.binary_cross_entropy_with_logits(pred_class_logits, gt_classes, reduction='none')

    if sample_weight is not None:
        targets_mask = torch.where(gt_classes.detach() > 0.5,
                                   torch.ones(1, device="cuda"), torch.zeros(1, device="cuda"))  # dtype float32
        weight = ratio2weight(targets_mask, sample_weight)
        loss = loss * weight

    with torch.no_grad():
        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)

    loss = loss.sum() / non_zero_cnt
    return loss
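
`ratio` here is the per-attribute positive frequency (see `AttrDataset.sample_weights` below), so rare positives are up-weighted. A small numeric check (not from the commit):

```python
import torch

ratio = torch.tensor([0.1])                  # attribute positive in 10% of images
ratio2weight(torch.tensor([[1.]]), ratio)    # exp(0.9) ~= 2.46: rare positive, larger weight
ratio2weight(torch.tensor([[0.]]), ratio)    # exp(0.1) ~= 1.11: common negative, smaller weight
```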
@@ -0,0 +1,47 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch.utils.data import Dataset

from fastreid.data.data_utils import read_image


class AttrDataset(Dataset):
    """Image Person Attribute Dataset"""

    def __init__(self, img_items, attr_dict, transform=None):
        self.img_items = img_items
        self.attr_dict = attr_dict
        self.transform = transform

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        img_path, labels = self.img_items[index]
        img = read_image(img_path)
        if self.transform is not None: img = self.transform(img)

        labels = torch.from_numpy(labels)

        return {
            "images": img,
            "targets": labels,
            "img_paths": img_path,
        }

    @property
    def num_classes(self):
        return len(self.attr_dict)

    @property
    def sample_weights(self):
        sample_weights = torch.zeros(self.num_classes, dtype=torch.float)
        for _, attr in self.img_items:
            sample_weights += torch.from_numpy(attr)
        sample_weights /= len(self)
        return sample_weights
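
`sample_weights` is simply the fraction of images in which each attribute is positive, which is exactly the `ratio` argument `ratio2weight` expects. A toy illustration (hypothetical items and attribute names):

```python
import numpy as np

items = [("a.jpg", np.array([1., 0.], dtype=np.float32)),
         ("b.jpg", np.array([1., 1.], dtype=np.float32))]
ds = AttrDataset(items, attr_dict={0: "male", 1: "hat"})   # attribute names made up
ds.sample_weights                                          # tensor([1.0, 0.5])
```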
@@ -0,0 +1,17 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

from fastreid.config import CfgNode as CN


def add_attr_config(cfg):
    _C = cfg

    _C.MODEL.LOSSES.BCE = CN()
    _C.MODEL.LOSSES.BCE.WEIGHT_ENABLED = True
    _C.MODEL.LOSSES.BCE.SCALE = 1.

    _C.TEST.THRES = 0.5

@@ -0,0 +1,82 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import os
import torch
from torch.utils.data import DataLoader
from fastreid.utils import comm

from .common_attr import AttrDataset
from fastreid.data import samplers
from fastreid.data.build import fast_batch_collator
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.transforms import build_transforms

_root = os.getenv("FASTREID_DATASETS", "datasets")


def build_attr_train_loader(cfg):
    cfg = cfg.clone()
    cfg.defrost()

    train_items = list()
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, "attr_dict in {} does not match with previous ones".format(d)
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    # SOLVER.MAX_ITER is given in epochs; convert it to iterations here
    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, attr_dict, train_transforms)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader


def build_attr_test_loader(cfg, dataset_name):
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=0,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
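
Note the epoch-to-iteration conversion in `build_attr_train_loader`: `cfg.SOLVER.MAX_ITER` is written in epochs and multiplied by the per-epoch iteration count. A quick check with the PA100k numbers and the base config below:

```python
# ~80,000 PA100k training images, SOLVER.IMS_PER_BATCH = 64 (Base-attribute.yml)
iters_per_epoch = 80000 // 64       # 1250
max_iter = 30 * iters_per_epoch     # SOLVER.MAX_ITER: 30 "epochs" -> 37500 iterations
```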
@@ -0,0 +1,8 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

# Attribute datasets
from .pa100k import PA100K

@@ -0,0 +1,127 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import copy
import logging
import os

from tabulate import tabulate
from termcolor import colored

logger = logging.getLogger("fastreid." + __name__)


class Dataset(object):

    def __init__(
            self,
            train,
            val,
            test,
            attr_dict,
            mode='train',
            verbose=True,
            **kwargs,
    ):
        self.train = train
        self.val = val
        self.test = test
        self._attr_dict = attr_dict
        self._num_attrs = len(self.attr_dict)

        if mode == 'train':
            self.data = self.train
        elif mode == 'val':
            self.data = self.val
        else:
            self.data = self.test

    @property
    def num_attrs(self):
        return self._num_attrs

    @property
    def attr_dict(self):
        return self._attr_dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raise NotImplementedError

    def check_before_run(self, required_files):
        """Checks if required files exist before going deeper.
        Args:
            required_files (str or list): string file name(s).
        """
        if isinstance(required_files, str):
            required_files = [required_files]

        for fpath in required_files:
            if not os.path.exists(fpath):
                raise RuntimeError('"{}" is not found'.format(fpath))

    def combine_all(self):
        """Combines train, val and test in a dataset for training."""
        # NOTE: carried over from the re-id base Dataset; it assumes
        # `self._junk_pids`, `self.query` and `self.gallery` exist.
        combined = copy.deepcopy(self.train)

        def _combine_data(data):
            for img_path, pid, camid in data:
                if pid in self._junk_pids:
                    continue
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
                combined.append((img_path, pid, camid))

        _combine_data(self.query)
        _combine_data(self.gallery)

        self.train = combined
        self.num_train_pids = self.get_num_pids(self.train)

    def show_train(self):
        num_train = len(self.train)
        num_val = len(self.val)
        num_total = num_train + num_val

        headers = ['subset', '# images']
        csv_results = [
            ['train', num_train],
            ['val', num_val],
            ['total', num_total],
        ]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))
        logger.info("attributes:")
        for label, attr in self.attr_dict.items():
            logger.info('{:3d}: {}'.format(label, attr))
        logger.info("------------------------------")
        logger.info("# attributes: {}".format(len(self.attr_dict)))

    def show_test(self):
        num_test = len(self.test)

        headers = ['subset', '# images']
        csv_results = [
            ['test', num_test],
        ]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan"))

@@ -0,0 +1,65 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import os.path as osp

import numpy as np
from scipy.io import loadmat

from fastreid.data.datasets import DATASET_REGISTRY

from .bases import Dataset


@DATASET_REGISTRY.register()
class PA100K(Dataset):
    """Pedestrian attribute dataset:
    80k training images + 20k test images.

    The folder structure should be:
        pa100k/
            data/  # images
            annotation.mat
    """
    dataset_dir = 'PA-100K'

    def __init__(self, root='', **kwargs):
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)
        self.data_dir = osp.join(self.dataset_dir, 'data')
        self.anno_mat_path = osp.join(
            self.dataset_dir, 'annotation.mat'
        )

        required_files = [self.data_dir, self.anno_mat_path]
        self.check_before_run(required_files)

        train, val, test, attr_dict = self.extract_data()
        super(PA100K, self).__init__(train, val, test, attr_dict=attr_dict, **kwargs)

    def extract_data(self):
        # anno_mat is a dictionary with keys: ['test_images_name', 'val_images_name',
        # 'train_images_name', 'val_label', 'attributes', 'test_label', 'train_label']
        anno_mat = loadmat(self.anno_mat_path)

        def _extract(key_name, key_label):
            names = anno_mat[key_name]
            labels = anno_mat[key_label]
            num_imgs = names.shape[0]
            data = []
            for i in range(num_imgs):
                name = names[i, 0][0]
                attrs = labels[i, :].astype(np.float32)
                img_path = osp.join(self.data_dir, name)
                data.append((img_path, attrs))
            return data

        train = _extract('train_images_name', 'train_label')
        val = _extract('val_images_name', 'val_label')
        test = _extract('test_images_name', 'test_label')
        attrs = anno_mat['attributes']
        attr_dict = {i: str(attr[0][0]) for i, attr in enumerate(attrs)}

        return train, val, test, attr_dict

@@ -0,0 +1,63 @@
MODEL:
  META_ARCHITECTURE: "AttrBaseline"

  BACKBONE:
    NAME: "build_resnet_backbone"
    NORM: "BN"
    DEPTH: "50x"
    LAST_STRIDE: 2
    FEAT_DIM: 2048
    WITH_IBN: False
    PRETRAIN: True
    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50-19c8e357.pth"

  HEADS:
    NAME: "AttrHead"
    NORM: "BN"
    WITH_BNNECK: True
    POOL_LAYER: "fastavgpool"
    CLS_LAYER: "linear"
    NUM_CLASSES: 26

  LOSSES:
    NAME: ("BinaryCrossEntropyLoss",)

    BCE:
      WEIGHT_ENABLED: True
      SCALE: 1.

INPUT:
  SIZE_TRAIN: [256, 128]
  SIZE_TEST: [256, 128]
  REA:
    ENABLED: False
  DO_PAD: True

DATALOADER:
  NUM_WORKERS: 8

SOLVER:
  OPT: "SGD"
  MAX_ITER: 30
  BASE_LR: 0.01
  BIAS_LR_FACTOR: 2.
  HEADS_LR_FACTOR: 10.
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  IMS_PER_BATCH: 64

  SCHED: "WarmupCosineAnnealingLR"
  DELAY_ITERS: 5
  ETA_MIN_LR: 0.00001

  WARMUP_FACTOR: 0.01
  WARMUP_ITERS: 5

  CHECKPOINT_PERIOD: 10

TEST:
  EVAL_PERIOD: 10
  IMS_PER_BATCH: 256

CUDNN_BENCHMARK: True

@@ -0,0 +1,7 @@
_BASE_: "Base-attribute.yml"

DATASETS:
  NAMES: ("PA100K",)
  TESTS: ("PA100K",)

OUTPUT_DIR: "projects/attribute_recognition/logs/pa100k/strong_baseline"

@@ -0,0 +1,58 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import sys

sys.path.append('.')

from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer

from attribute_baseline import *


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    add_attr_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)

    if args.eval_only:
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False
        model = AttrTrainer.build_model(cfg)

        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

        res = AttrTrainer.test(cfg, model)
        return res

    trainer = AttrTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )