diff --git a/README.md b/README.md
index ff76709..1d0b0b5 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identificati
 
 ## What's New
 
+- [Sep 2020] Added person attribute recognition code based on fastreid. See `projects/attribute_recognition`.
 - [Sep 2020] Automatic Mixed Precision training is supported with pytorch1.6 built-in `torch.cuda.amp`. Set `cfg.SOLVER.AMP_ENABLED=True` to switch it on.
 - [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/DistillReID) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
 - [Aug 2020] ONNX/TensorRT converter is supported.
diff --git a/fastreid/modeling/heads/__init__.py b/fastreid/modeling/heads/__init__.py
index e4f5082..7d62330 100644
--- a/fastreid/modeling/heads/__init__.py
+++ b/fastreid/modeling/heads/__init__.py
@@ -8,4 +8,4 @@ from .build import REID_HEADS_REGISTRY, build_heads
 
 # import all the meta_arch, so they will be registered
 from .embedding_head import EmbeddingHead
-from .cls_head import CLSHead
+from .attr_head import AttrHead
diff --git a/fastreid/modeling/heads/cls_head.py b/fastreid/modeling/heads/attr_head.py
similarity index 79%
rename from fastreid/modeling/heads/cls_head.py
rename to fastreid/modeling/heads/attr_head.py
index 42d7c81..a0618db 100644
--- a/fastreid/modeling/heads/cls_head.py
+++ b/fastreid/modeling/heads/attr_head.py
@@ -4,7 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """
 
-import torch.nn.functional as F
+import torch
 from torch import nn
 
 from fastreid.layers import *
@@ -13,7 +13,7 @@ from .build import REID_HEADS_REGISTRY
 
 
 @REID_HEADS_REGISTRY.register()
-class CLSHead(nn.Module):
+class AttrHead(nn.Module):
     def __init__(self, cfg):
         super().__init__()
         # fmt: off
@@ -46,7 +46,7 @@ class CLSHead(nn.Module):
         # bottleneck = []
         # if with_bnneck:
         #     bottleneck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
-        bottleneck = [nn.BatchNorm1d(feat_dim)]
+        bottleneck = [nn.BatchNorm1d(num_classes)]
 
         self.bottleneck = nn.Sequential(*bottleneck)
 
@@ -60,16 +60,18 @@ class CLSHead(nn.Module):
         global_feat = self.pool_layer(features)
         global_feat = global_feat[..., 0, 0]
 
-        if self.classifier.__class__.__name__ == 'Linear':
-            cls_outputs = self.classifier(global_feat)
-            pred_class_logits = F.linear(global_feat, self.classifier.weight)
-        else:
-            cls_outputs = self.classifier(global_feat, targets)
-            pred_class_logits = self.classifier.s * F.linear(F.normalize(global_feat),
-                                                             F.normalize(self.classifier.weight))
+        classifier_name = self.classifier.__class__.__name__
+        # fmt: off
+        if classifier_name == 'Linear': cls_outputs = self.classifier(global_feat)
+        else:                           cls_outputs = self.classifier(global_feat, targets)
+        # fmt: on
 
         cls_outputs = self.bottleneck(cls_outputs)
-        return {
-            "cls_outputs": cls_outputs,
-            "pred_class_logits": pred_class_logits,
-        }
+
+        if self.training:
+            return {
+                "cls_outputs": cls_outputs,
+            }
+        else:
+            cls_outputs = torch.sigmoid(cls_outputs)
+            return cls_outputs
diff --git a/projects/attribute_recognition/README.md b/projects/attribute_recognition/README.md
new file mode 100644
index 0000000..6efac09
--- /dev/null
+++ b/projects/attribute_recognition/README.md
@@ -0,0 +1,26 @@
+# Person Attribute Recognition in FastReID
+
+## Training and Evaluation
+
+To train a model, run:
+
+```bash
+python3 projects/attribute_recognition/train_net.py --config-file <config.yml> --num-gpus 1
+```
+
+For example, to train the attribute recognition network with a ResNet-50 backbone on the PA100k
+dataset, run:
+
+```bash
+python3 projects/attribute_recognition/train_net.py --config-file projects/attribute_recognition/configs/pa100.yml --num-gpus 4
+```
+
+## Results
+
+### PA100k
+
+| Method | mA | Accu | Prec | Recall | F1 |
+|:--:|:--:|:--:|:--:|:--:|:--:|
+| Strongbaseline | 77.76 | 77.59 | 88.38 | 84.35 | 86.32 |
+
+More datasets and test results will be added soon. Stay tuned!
diff --git a/projects/attribute_recognition/attribute_baseline/__init__.py b/projects/attribute_recognition/attribute_baseline/__init__.py
new file mode 100644
index 0000000..a769ede
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/__init__.py
@@ -0,0 +1,12 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+from .config import add_attr_config
+from .datasets import *
+from .attr_baseline import AttrBaseline
+from .attr_evaluation import AttrEvaluator
+from .data_build import build_attr_train_loader, build_attr_test_loader
+from .attr_trainer import AttrTrainer
diff --git a/projects/attribute_recognition/attribute_baseline/attr_baseline.py b/projects/attribute_recognition/attribute_baseline/attr_baseline.py
new file mode 100644
index 0000000..1961f68
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/attr_baseline.py
@@ -0,0 +1,41 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from fastreid.modeling.meta_arch.baseline import Baseline
+from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
+from .bce_loss import cross_entropy_sigmoid_loss
+
+
+@META_ARCH_REGISTRY.register()
+class AttrBaseline(Baseline):
+
+    def losses(self, outs, sample_weight=None):
+        r"""
+        Compute losses from the modeling outputs; the input arguments of the loss
+        functions must match the outputs of the model forward pass.
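+
+        Here ``outs["outputs"]["cls_outputs"]`` holds the attribute logits
+        produced by ``AttrHead``, and ``outs["targets"]`` holds the multi-hot
+        ground-truth attribute labels.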
+        """
+        # fmt: off
+        outputs           = outs["outputs"]
+        gt_labels         = outs["targets"]
+        # model predictions
+        # pred_class_logits = outputs['pred_class_logits'].detach()
+        cls_outputs       = outputs['cls_outputs']
+        # fmt: on
+
+        # Log prediction accuracy
+        # log_accuracy(pred_class_logits, gt_labels)
+
+        loss_dict = {}
+        loss_names = self._cfg.MODEL.LOSSES.NAME
+
+        if "BinaryCrossEntropyLoss" in loss_names:
+            loss_dict['loss_bce'] = cross_entropy_sigmoid_loss(
+                cls_outputs,
+                gt_labels,
+                sample_weight,
+            ) * self._cfg.MODEL.LOSSES.BCE.SCALE
+
+        return loss_dict
diff --git a/projects/attribute_recognition/attribute_baseline/attr_evaluation.py b/projects/attribute_recognition/attribute_baseline/attr_evaluation.py
new file mode 100644
index 0000000..d102284
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/attr_evaluation.py
@@ -0,0 +1,96 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+import copy
+import logging
+from collections import OrderedDict
+
+import torch
+
+from fastreid.evaluation.evaluator import DatasetEvaluator
+from fastreid.utils import comm
+
+logger = logging.getLogger(__name__)
+
+
+class AttrEvaluator(DatasetEvaluator):
+    def __init__(self, cfg, attr_dict, thres=0.5, output_dir=None):
+        self.cfg = cfg
+        self.attr_dict = attr_dict
+        self.thres = thres
+        self._output_dir = output_dir
+
+        self.pred_logits = []
+        self.gt_labels = []
+
+    def reset(self):
+        self.pred_logits = []
+        self.gt_labels = []
+
+    def process(self, inputs, outputs):
+        self.gt_labels.extend(inputs["targets"])
+        self.pred_logits.extend(outputs.cpu())
+
+    @staticmethod
+    def get_attr_metrics(gt_labels, pred_logits, thres):
+
+        pred_labels = copy.deepcopy(pred_logits)
+        pred_labels[pred_logits < thres] = 0
+        pred_labels[pred_logits >= thres] = 1
+
+        # Compute label-based metric
+        overlaps = pred_labels * gt_labels
+        correct_pos = overlaps.sum(axis=0)
+        real_pos = gt_labels.sum(axis=0)
+        inv_overlaps = (1 - pred_labels) * (1 - gt_labels)
+        correct_neg = inv_overlaps.sum(axis=0)
+        real_neg = (1 - gt_labels).sum(axis=0)
+
+        # Compute instance-based accuracy
+        pred_labels = pred_labels.astype(bool)
+        gt_labels = gt_labels.astype(bool)
+        intersect = (pred_labels & gt_labels).astype(float)
+        union = (pred_labels | gt_labels).astype(float)
+        ins_acc = (intersect.sum(axis=1) / union.sum(axis=1)).mean()
+        ins_prec = (intersect.sum(axis=1) / pred_labels.astype(float).sum(axis=1)).mean()
+        ins_rec = (intersect.sum(axis=1) / gt_labels.astype(float).sum(axis=1)).mean()
+        ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec)
+
+        term1 = correct_pos / real_pos
+        term2 = correct_neg / real_neg
+        label_mA_verbose = (term1 + term2) * 0.5
+        label_mA = label_mA_verbose.mean()
+
+        results = OrderedDict()
+        results["Accu"] = ins_acc
+        results["Prec"] = ins_prec
+        results["Recall"] = ins_rec
+        results["F1"] = ins_f1
+        results["mA"] = label_mA
+        return results
+
+    def evaluate(self):
+        if comm.get_world_size() > 1:
+            comm.synchronize()
+            pred_logits = comm.gather(self.pred_logits)
+            pred_logits = sum(pred_logits, [])
+
+            gt_labels = comm.gather(self.gt_labels)
+            gt_labels = sum(gt_labels, [])
+
+            if not comm.is_main_process():
+                return {}
+        else:
+            pred_logits = self.pred_logits
+            gt_labels = self.gt_labels
+
+        pred_logits = torch.stack(pred_logits, dim=0).numpy()
+        gt_labels = torch.stack(gt_labels, dim=0).numpy()
+
+        # Pedestrian attribute metrics
+        thres = self.cfg.TEST.THRES
+        self._results = self.get_attr_metrics(gt_labels, pred_logits, thres)
+
+        return copy.deepcopy(self._results)
diff --git a/projects/attribute_recognition/attribute_baseline/attr_trainer.py b/projects/attribute_recognition/attribute_baseline/attr_trainer.py
new file mode 100644
index 0000000..d63f124
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/attr_trainer.py
@@ -0,0 +1,89 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import time
+import torch
+from torch.nn.parallel import DistributedDataParallel
+from torch.cuda import amp
+from fastreid.engine import DefaultTrainer
+from .data_build import build_attr_train_loader, build_attr_test_loader
+from .attr_evaluation import AttrEvaluator
+
+
+class AttrTrainer(DefaultTrainer):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+        # Sample weights for imbalanced attribute classification
+        bce_weight_enabled = self.cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED
+        # fmt: off
+        if bce_weight_enabled: self.sample_weights = self.data_loader.dataset.sample_weights.to("cuda")
+        else:                  self.sample_weights = None
+        # fmt: on
+
+    @classmethod
+    def build_train_loader(cls, cfg):
+        return build_attr_train_loader(cfg)
+
+    @classmethod
+    def build_test_loader(cls, cfg, dataset_name):
+        return build_attr_test_loader(cfg, dataset_name)
+
+    @classmethod
+    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+        data_loader = cls.build_test_loader(cfg, dataset_name)
+        return data_loader, AttrEvaluator(cfg, data_loader.dataset.attr_dict, cfg.TEST.THRES, output_folder)
+
+    def run_step(self):
+        r"""
+        Implement the attribute model training logic.
+        """
+        assert self.model.training, "[AttrTrainer] model was changed to eval mode!"
+        start = time.perf_counter()
+        """
+        If you want to do something with the data, you can wrap the dataloader.
+        """
+        data = next(self._data_loader_iter)
+        data_time = time.perf_counter() - start
+
+        """
+        If you want to do something with the heads, you can wrap the model.
+        """
+
+        with amp.autocast(enabled=self.amp_enabled):
+            outs = self.model(data)
+
+            # Compute loss
+            if isinstance(self.model, DistributedDataParallel):
+                loss_dict = self.model.module.losses(outs, self.sample_weights)
+            else:
+                loss_dict = self.model.losses(outs, self.sample_weights)
+
+            losses = sum(loss_dict.values())
+
+        with torch.cuda.stream(torch.cuda.Stream()):
+            metrics_dict = loss_dict
+            metrics_dict["data_time"] = data_time
+            self._write_metrics(metrics_dict)
+            self._detect_anomaly(losses, loss_dict)
+
+        """
+        If you need to accumulate gradients or do something similar, you can
+        wrap the optimizer with your custom `zero_grad()` method.
+        """
+        self.optimizer.zero_grad()
+
+        if self.amp_enabled:
+            self.scaler.scale(losses).backward()
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+        else:
+            losses.backward()
+        """
+        If you need gradient clipping/scaling or other processing, you can
+        wrap the optimizer with your custom `step()` method.
+        """
+        self.optimizer.step()
diff --git a/projects/attribute_recognition/attribute_baseline/bce_loss.py b/projects/attribute_recognition/attribute_baseline/bce_loss.py
new file mode 100644
index 0000000..639c4a3
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/bce_loss.py
@@ -0,0 +1,33 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+
+
+def ratio2weight(targets, ratio):
+    pos_weights = targets * (1 - ratio)
+    neg_weights = (1 - targets) * ratio
+    weights = torch.exp(neg_weights + pos_weights)
+
+    weights[targets > 1] = 0.0
+    return weights
+
+
+def cross_entropy_sigmoid_loss(pred_class_logits, gt_classes, sample_weight=None):
+    loss = F.binary_cross_entropy_with_logits(pred_class_logits, gt_classes, reduction='none')
+
+    if sample_weight is not None:
+        targets_mask = torch.where(gt_classes.detach() > 0.5,
+                                   torch.ones(1, device="cuda"), torch.zeros(1, device="cuda"))  # dtype float32
+        weight = ratio2weight(targets_mask, sample_weight)
+        loss = loss * weight
+
+    with torch.no_grad():
+        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
+
+    loss = loss.sum() / non_zero_cnt
+    return loss
diff --git a/projects/attribute_recognition/attribute_baseline/common_attr.py b/projects/attribute_recognition/attribute_baseline/common_attr.py
new file mode 100644
index 0000000..896ec55
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/common_attr.py
@@ -0,0 +1,47 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch.utils.data import Dataset
+
+from fastreid.data.data_utils import read_image
+
+
+class AttrDataset(Dataset):
+    """Image Person Attribute Dataset"""
+
+    def __init__(self, img_items, attr_dict, transform=None):
+        self.img_items = img_items
+        self.attr_dict = attr_dict
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.img_items)
+
+    def __getitem__(self, index):
+        img_path, labels = self.img_items[index]
+        img = read_image(img_path)
+        if self.transform is not None: img = self.transform(img)
+
+        labels = torch.from_numpy(labels)
+
+        return {
+            "images": img,
+            "targets": labels,
+            "img_paths": img_path,
+        }
+
+    @property
+    def num_classes(self):
+        return len(self.attr_dict)
+
+    @property
+    def sample_weights(self):
+        sample_weights = torch.zeros(self.num_classes, dtype=torch.float)
+        for _, attr in self.img_items:
+            sample_weights += torch.from_numpy(attr)
+        sample_weights /= len(self)
+        return sample_weights
diff --git a/projects/attribute_recognition/attribute_baseline/config.py b/projects/attribute_recognition/attribute_baseline/config.py
new file mode 100644
index 0000000..5b69581
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/config.py
@@ -0,0 +1,17 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+from fastreid.config import CfgNode as CN
+
+
+def add_attr_config(cfg):
+    _C = cfg
+
+    _C.MODEL.LOSSES.BCE = CN()
+    _C.MODEL.LOSSES.BCE.WEIGHT_ENABLED = True
+    _C.MODEL.LOSSES.BCE.SCALE = 1.
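+
+    # Threshold applied to the sigmoid outputs to binarize per-attribute
+    # predictions during evaluation (see AttrEvaluator).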
+    _C.TEST.THRES = 0.5
diff --git a/projects/attribute_recognition/attribute_baseline/data_build.py b/projects/attribute_recognition/attribute_baseline/data_build.py
new file mode 100644
index 0000000..eb04922
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/data_build.py
@@ -0,0 +1,82 @@
+# encoding: utf-8
+"""
+@author: l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import os
+import torch
+from torch.utils.data import DataLoader
+from fastreid.utils import comm
+
+from .common_attr import AttrDataset
+from fastreid.data import samplers
+from fastreid.data.build import fast_batch_collator
+from fastreid.data.datasets import DATASET_REGISTRY
+from fastreid.data.transforms import build_transforms
+
+_root = os.getenv("FASTREID_DATASETS", "datasets")
+
+
+def build_attr_train_loader(cfg):
+    cfg = cfg.clone()
+    cfg.defrost()
+
+    train_items = list()
+    attr_dict = None
+    for d in cfg.DATASETS.NAMES:
+        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
+        if comm.is_main_process():
+            dataset.show_train()
+        if attr_dict is not None:
+            assert attr_dict == dataset.attr_dict, "attr_dict in {} does not match with previous ones".format(d)
+        else:
+            attr_dict = dataset.attr_dict
+        train_items.extend(dataset.train)
+
+    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
+    cfg.SOLVER.MAX_ITER *= iters_per_epoch
+    train_transforms = build_transforms(cfg, is_train=True)
+    train_set = AttrDataset(train_items, attr_dict, train_transforms)
+
+    num_workers = cfg.DATALOADER.NUM_WORKERS
+    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
+
+    data_sampler = samplers.TrainingSampler(len(train_set))
+    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
+
+    train_loader = torch.utils.data.DataLoader(
+        train_set,
+        num_workers=num_workers,
+        batch_sampler=batch_sampler,
+        collate_fn=fast_batch_collator,
+        pin_memory=True,
+    )
+    return train_loader
+
+
+def build_attr_test_loader(cfg, dataset_name):
+    cfg = cfg.clone()
+    cfg.defrost()
+
+    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
+    if comm.is_main_process():
+        dataset.show_test()
+    test_items = dataset.test
+
+    test_transforms = build_transforms(cfg, is_train=False)
+    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)
+
+    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
+    data_sampler = samplers.InferenceSampler(len(test_set))
+    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
+    test_loader = DataLoader(
+        test_set,
+        batch_sampler=batch_sampler,
+        num_workers=0,  # save some memory
+        collate_fn=fast_batch_collator,
+        pin_memory=True,
+    )
+    return test_loader
+
+
diff --git a/projects/attribute_recognition/attribute_baseline/datasets/__init__.py b/projects/attribute_recognition/attribute_baseline/datasets/__init__.py
new file mode 100644
index 0000000..18050a2
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/datasets/__init__.py
@@ -0,0 +1,8 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+# Attribute datasets
+from .pa100k import PA100K
diff --git a/projects/attribute_recognition/attribute_baseline/datasets/bases.py b/projects/attribute_recognition/attribute_baseline/datasets/bases.py
new file mode 100644
index 0000000..ee0bbe8
--- /dev/null
+++ b/projects/attribute_recognition/attribute_baseline/datasets/bases.py
@@ -0,0 +1,127 @@
+# encoding: utf-8
+""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +import copy +import logging +import os + +from tabulate import tabulate +from termcolor import colored + +logger = logging.getLogger("fastreid." + __name__) + + +class Dataset(object): + + def __init__( + self, + train, + val, + test, + attr_dict, + mode='train', + verbose=True, + **kwargs, + ): + self.train = train + self.val = val + self.test = test + self._attr_dict = attr_dict + self._num_attrs = len(self.attr_dict) + + if mode == 'train': + self.data = self.train + elif mode == 'val': + self.data = self.val + else: + self.data = self.test + + @property + def num_attrs(self): + return self._num_attrs + + @property + def attr_dict(self): + return self._attr_dict + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + raise NotImplementedError + + def check_before_run(self, required_files): + """Checks if required files exist before going deeper. + Args: + required_files (str or list): string file name(s). + """ + if isinstance(required_files, str): + required_files = [required_files] + + for fpath in required_files: + if not os.path.exists(fpath): + raise RuntimeError('"{}" is not found'.format(fpath)) + + def combine_all(self): + """Combines train, val and test in a dataset for training.""" + combined = copy.deepcopy(self.train) + + def _combine_data(data): + for img_path, pid, camid in data: + if pid in self._junk_pids: + continue + pid = self.dataset_name + "_" + str(pid) + camid = self.dataset_name + "_" + str(camid) + combined.append((img_path, pid, camid)) + + _combine_data(self.query) + _combine_data(self.gallery) + + self.train = combined + self.num_train_pids = self.get_num_pids(self.train) + + def show_train(self): + num_train = len(self.train) + num_val = len(self.val) + num_total = num_train + num_val + + headers = ['subset', '# images'] + csv_results = [ + ['train', num_train], + ['val', num_val], + ['total', num_total], + ] + + # tabulate it + table = tabulate( + csv_results, + tablefmt="pipe", + headers=headers, + numalign="left", + ) + logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan")) + logger.info("attributes:") + for label, attr in self.attr_dict.items(): + logger.info('{:3d}: {}'.format(label, attr)) + logger.info("------------------------------") + logger.info("# attributes: {}".format(len(self.attr_dict))) + + def show_test(self): + num_test = len(self.test) + + headers = ['subset', '# images'] + csv_results = [ + ['test', num_test], + ] + + # tabulate it + table = tabulate( + csv_results, + tablefmt="pipe", + headers=headers, + numalign="left", + ) + logger.info(f"=> Loaded {self.__class__.__name__} in csv format: \n" + colored(table, "cyan")) diff --git a/projects/attribute_recognition/attribute_baseline/datasets/pa100k.py b/projects/attribute_recognition/attribute_baseline/datasets/pa100k.py new file mode 100644 index 0000000..5d6d154 --- /dev/null +++ b/projects/attribute_recognition/attribute_baseline/datasets/pa100k.py @@ -0,0 +1,65 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +import os.path as osp + +import numpy as np +from scipy.io import loadmat + +from fastreid.data.datasets import DATASET_REGISTRY + +from .bases import Dataset + + +@DATASET_REGISTRY.register() +class PA100K(Dataset): + """Pedestrian attribute dataset. + 80k training images + 20k test images. 
+    The folder structure should be:
+        pa100k/
+            data/ # images
+            annotation.mat
+    """
+    dataset_dir = 'PA-100K'
+
+    def __init__(self, root='', **kwargs):
+        self.root = root
+        self.dataset_dir = osp.join(self.root, self.dataset_dir)
+        self.data_dir = osp.join(self.dataset_dir, 'data')
+        self.anno_mat_path = osp.join(
+            self.dataset_dir, 'annotation.mat'
+        )
+
+        required_files = [self.data_dir, self.anno_mat_path]
+        self.check_before_run(required_files)
+
+        train, val, test, attr_dict = self.extract_data()
+        super(PA100K, self).__init__(train, val, test, attr_dict=attr_dict, **kwargs)
+
+    def extract_data(self):
+        # anno_mat is a dictionary with keys: ['test_images_name', 'val_images_name',
+        # 'train_images_name', 'val_label', 'attributes', 'test_label', 'train_label']
+        anno_mat = loadmat(self.anno_mat_path)
+
+        def _extract(key_name, key_label):
+            names = anno_mat[key_name]
+            labels = anno_mat[key_label]
+            num_imgs = names.shape[0]
+            data = []
+            for i in range(num_imgs):
+                name = names[i, 0][0]
+                attrs = labels[i, :].astype(np.float32)
+                img_path = osp.join(self.data_dir, name)
+                data.append((img_path, attrs))
+            return data
+
+        train = _extract('train_images_name', 'train_label')
+        val = _extract('val_images_name', 'val_label')
+        test = _extract('test_images_name', 'test_label')
+        attrs = anno_mat['attributes']
+        attr_dict = {i: str(attr[0][0]) for i, attr in enumerate(attrs)}
+
+        return train, val, test, attr_dict
diff --git a/projects/attribute_recognition/configs/Base-attribute.yml b/projects/attribute_recognition/configs/Base-attribute.yml
new file mode 100644
index 0000000..ba6da27
--- /dev/null
+++ b/projects/attribute_recognition/configs/Base-attribute.yml
@@ -0,0 +1,63 @@
+MODEL:
+  META_ARCHITECTURE: "AttrBaseline"
+
+  BACKBONE:
+    NAME: "build_resnet_backbone"
+    NORM: "BN"
+    DEPTH: "50x"
+    LAST_STRIDE: 2
+    FEAT_DIM: 2048
+    WITH_IBN: False
+    PRETRAIN: True
+    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50-19c8e357.pth"
+
+  HEADS:
+    NAME: "AttrHead"
+    NORM: "BN"
+    WITH_BNNECK: True
+    POOL_LAYER: "fastavgpool"
+    CLS_LAYER: "linear"
+    NUM_CLASSES: 26
+
+  LOSSES:
+    NAME: ("BinaryCrossEntropyLoss",)
+
+    BCE:
+      WEIGHT_ENABLED: True
+      SCALE: 1.
+
+INPUT:
+  SIZE_TRAIN: [256, 128]
+  SIZE_TEST: [256, 128]
+  REA:
+    ENABLED: False
+  DO_PAD: True
+
+DATALOADER:
+  NUM_WORKERS: 8
+
+SOLVER:
+  OPT: "SGD"
+  MAX_ITER: 30
+  BASE_LR: 0.01
+  BIAS_LR_FACTOR: 2.
+  HEADS_LR_FACTOR: 10.
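+  # BIAS_LR_FACTOR and HEADS_LR_FACTOR scale the base LR for bias parameters
+  # and for the attribute head, respectively.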
+  WEIGHT_DECAY: 0.0005
+  WEIGHT_DECAY_BIAS: 0.0005
+  IMS_PER_BATCH: 64
+
+  SCHED: "WarmupCosineAnnealingLR"
+  DELAY_ITERS: 5
+  ETA_MIN_LR: 0.00001
+
+  WARMUP_FACTOR: 0.01
+  WARMUP_ITERS: 5
+
+  CHECKPOINT_PERIOD: 10
+
+TEST:
+  EVAL_PERIOD: 10
+  IMS_PER_BATCH: 256
+
+CUDNN_BENCHMARK: True
+
diff --git a/projects/attribute_recognition/configs/pa100.yml b/projects/attribute_recognition/configs/pa100.yml
new file mode 100644
index 0000000..b7de5ee
--- /dev/null
+++ b/projects/attribute_recognition/configs/pa100.yml
@@ -0,0 +1,7 @@
+_BASE_: "Base-attribute.yml"
+
+DATASETS:
+  NAMES: ("PA100K",)
+  TESTS: ("PA100K",)
+
+OUTPUT_DIR: "projects/attribute_recognition/logs/pa100k/strong_baseline"
\ No newline at end of file
diff --git a/projects/attribute_recognition/train_net.py b/projects/attribute_recognition/train_net.py
new file mode 100644
index 0000000..1aa97d6
--- /dev/null
+++ b/projects/attribute_recognition/train_net.py
@@ -0,0 +1,58 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+import sys
+
+sys.path.append('.')
+
+from fastreid.config import get_cfg
+from fastreid.engine import default_argument_parser, default_setup, launch
+from fastreid.utils.checkpoint import Checkpointer
+
+from attribute_baseline import *
+
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+    add_attr_config(cfg)
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    default_setup(cfg, args)
+    return cfg
+
+
+def main(args):
+    cfg = setup(args)
+
+    if args.eval_only:
+        cfg.defrost()
+        cfg.MODEL.BACKBONE.PRETRAIN = False
+        model = AttrTrainer.build_model(cfg)
+
+        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model
+
+        res = AttrTrainer.test(cfg, model)
+        return res
+
+    trainer = AttrTrainer(cfg)
+    trainer.resume_or_load(resume=args.resume)
+    return trainer.train()
+
+
+if __name__ == "__main__":
+    args = default_argument_parser().parse_args()
+    print("Command Line Args:", args)
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )
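As a quick sanity check of the reported metrics, the static `AttrEvaluator.get_attr_metrics` introduced above can be exercised on a toy prediction matrix. A minimal sketch, assuming `fastreid` is importable and the working directory is `projects/attribute_recognition` so that `attribute_baseline` resolves; the toy numbers are made up for illustration:

```python
import numpy as np

from attribute_baseline.attr_evaluation import AttrEvaluator

# Toy data: 3 images x 4 attributes (rows = images, columns = attributes).
gt_labels = np.array([[1, 0, 1, 0],
                      [0, 1, 1, 0],
                      [1, 1, 0, 1]], dtype=np.float32)

# Sigmoid scores, as AttrHead produces at inference time.
pred_scores = np.array([[0.9, 0.2, 0.7, 0.1],
                        [0.3, 0.8, 0.6, 0.4],
                        [0.7, 0.9, 0.2, 0.6]], dtype=np.float32)

# Scores are binarized at `thres`, then label-based mA and instance-based
# Accu/Prec/Recall/F1 are computed, matching the columns of the results table.
metrics = AttrEvaluator.get_attr_metrics(gt_labels, pred_scores, thres=0.5)
print(metrics)  # all 1.0 here, since the binarized predictions match the labels exactly
```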