Update first stable version v1.0

This commit is contained in:
sherlock 2019-01-10 18:39:31 +08:00
parent 519bac01fc
commit 69e12d989d
55 changed files with 1641 additions and 1042 deletions

View File

@ -1,21 +1,23 @@
# ReID_baseline
Baseline model (with bottleneck) for person ReID (using softmax and triplet loss). This is PyTorch version, [mxnet version](https://github.com/L1aoXingyu/reid_baseline_gluon) has a better result and more SOTA methods.
Baseline model (with bottleneck) for person ReID (using softmax and triplet loss).
We support
- multi-GPU training
- easy dataset preparation
- end-to-end training and evaluation
- [x] easy dataset preparation
- [x] end-to-end training and evaluation
- [x] highly modular management
## Get Started
The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
1. `cd` to folder where you want to download this repo
2. Run `git clone https://github.com/L1aoXingyu/reid_baseline.git`
3. Install dependencies:
- [pytorch 0.4](https://pytorch.org/)
- [pytorch 1.0](https://pytorch.org/)
- torchvision
- tensorflow (for tensorboard)
- [tensorboardX](https://github.com/lanpa/tensorboardX)
- [ignite](https://github.com/pytorch/ignite)
- [yacs](https://github.com/rbgirshick/yacs)
4. Prepare dataset
Create a directory to store reid datasets under this repo via
```bash
cd reid_baseline
@ -23,39 +25,43 @@ We support
```
1. Download dataset to `data/` from http://www.liangzheng.org/Project/project_reid.html
2. Extract the dataset and rename it to `market1501`. The data structure should look like:
```
market1501/
bounding_box_test/
bounding_box_train/
```bash
data
market1501
bounding_box_test/
bounding_box_train/
```
5. Prepare pretrained model if you don't have
```python
from torchvision import models
models.resnet50(pretrained=True)
```
Then it will automatically download model in `~.torch/models/`, you should set this path in `config.py`
This will automatically download the model to `~/.torch/models/`. Set this path in `config/defaults.py` to use it for all training runs, or set it in each individual training config file in `configs/`.
## Train
You can run
With most of the configuration files that we provide, you can run this command for training:
```bash
bash scripts/train_triplet_softmax.sh
python3 tools/train.py --config_file='configs/market1501_softmax_bs64.yml'
```
You can also modify your cfg parameters as follows:
```bash
python3 tools/train.py --config_file='configs/market1501_softmax_bs64.yml' INPUT.SIZE_TRAIN '(256, 128)' INPUT.SIZE_TEST '(256, 128)'
```
in `reid_baseline` folder if you want to train with softmax and triplet loss. You can find others train scripts in `scripts`.
## Results
**network architecture**
<div align=center>
<img src='https://ws3.sinaimg.cn/large/006tNbRwly1fvh3ekjh12j315k0j4q58.jpg' width='500'>
</div>
| cfg | market1501 | cuhk03 | dukemtmc |
| --- | -- | -- | -- |
| softmax, size=(384, 128), batch_size=64 | 92.5 (79.4) | 60.4 (56.1) | 84.6 (68.1) |
| softmax, size=(256, 128), batch_size=64 | 92.0 (80.4) | 60.5 (55.5) | 84.1(68.4) |
| softmax_triplet, size=(384, 128), batch_size=128(32 id x 4 imgs) | 93.2 (82.5) | - | 86.4 (73.1)
| softmax_triplet, size=(256, 128), batch_size=128(32 id x 4 imgs) | 93.8 (83.2) | 65.9 (61.4) | -
| config | Market1501 |
| --- | -- |
| bs(32) size(384,128) softmax | 92.2 (78.5) |
| bs(64) size(384,128) softmax | 92.5 (79.6) |
| bs(32) size(256,128) softmax | 92.0 (78.4) |
| bs(64) size(256,128) softmax | 91.7 (78.3) |
| bs(128) size(256,128) softmax | 91.2 (77.4) |
| triplet(p=32,k=4) size(256,128) | 88.3 (73.8) |
| triplet(p=16,k=4)+softmax size(384,128) | 93.1 (82.0) |
| triplet(p=24,k=4)+softmax size(384,128) | 91.7 (79.0) |

7
config/__init__.py Normal file
View File

@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .defaults import _C as cfg

101
config/defaults.py Normal file
View File

@ -0,0 +1,101 @@
from yacs.config import CfgNode as CN
# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
# Root node of the default configuration; it is re-exported as `cfg` by
# config/__init__.py.
_C = CN()
# -----------------------------------------------------------------------------
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Device identifier string (presumably handed to torch -- TODO confirm usage)
_C.MODEL.DEVICE = "cuda"
# Name of the backbone network
_C.MODEL.NAME = 'resnet50'
# NOTE(review): presumably the stride of the backbone's last stage -- confirm
# against the model builder
_C.MODEL.LAST_STRIDE = 1
# Path of the pretrained backbone weights (the README asks to set this here or
# in each file under configs/); empty by default
_C.MODEL.PRETRAIN_PATH = ''
# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Size of the image during training (presumably [height, width] -- TODO confirm)
_C.INPUT.SIZE_TRAIN = [384, 128]
# Size of the image during test (presumably [height, width] -- TODO confirm)
_C.INPUT.SIZE_TEST = [384, 128]
# Random probability for image horizontal flip
_C.INPUT.PROB = 0.5
# Per-channel mean values to be used for image normalization
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
# Per-channel std values to be used for image normalization
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Value of padding size
_C.INPUT.PADDING = 10
# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# Name of the dataset used for training. make_data_loader passes this value
# unchanged to init_dataset, which looks it up in its dataset factory.
# NOTE: the parentheses do NOT create a tuple here; the value is the plain
# string 'market1501'.
_C.DATASETS.NAMES = ('market1501')
# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading workers (passed as DataLoader num_workers)
_C.DATALOADER.NUM_WORKERS = 8
# Sampler for data loading: make_data_loader uses a shuffled DataLoader for
# 'softmax' and RandomIdentitySampler for any other value
_C.DATALOADER.SAMPLER = 'softmax'
# Number of instances handed to RandomIdentitySampler (presumably images per
# identity inside one batch -- TODO confirm against the sampler)
_C.DATALOADER.NUM_INSTANCE = 16
# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
# NOTE(review): the optimizer/scheduler code that consumes these values is not
# visible from here; the per-key notes below are assumptions to be confirmed.
_C.SOLVER = CN()
_C.SOLVER.OPTIMIZER_NAME = "Adam"
_C.SOLVER.MAX_EPOCHS = 50
_C.SOLVER.BASE_LR = 3e-4
# presumably a multiplier applied to BASE_LR for bias parameters -- confirm
_C.SOLVER.BIAS_LR_FACTOR = 2
_C.SOLVER.MOMENTUM = 0.9
# presumably the triplet-loss margin -- confirm
_C.SOLVER.MARGIN = 0.3
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
# presumably the LR decay factor applied at each entry of STEPS -- confirm
_C.SOLVER.GAMMA = 0.1
# NOTE(review): the second step (55) lies beyond the default MAX_EPOCHS (50);
# if STEPS are epochs it is never reached with the defaults -- confirm
_C.SOLVER.STEPS = (30, 55)
_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
_C.SOLVER.WARMUP_ITERS = 500
_C.SOLVER.WARMUP_METHOD = "linear"
_C.SOLVER.CHECKPOINT_PERIOD = 50
_C.SOLVER.LOG_PERIOD = 100
_C.SOLVER.EVAL_PERIOD = 50
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
# (make_data_loader passes this value directly as the training DataLoader
# batch_size)
_C.SOLVER.IMS_PER_BATCH = 64
# ---------------------------------------------------------------------------- #
# Test
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# Number of images per batch for the validation DataLoader
_C.TEST.IMS_PER_BATCH = 128
# presumably the path of the trained weights to evaluate -- TODO confirm
_C.TEST.WEIGHT = ""
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# presumably the directory that receives logs and checkpoints -- TODO confirm
_C.OUTPUT_DIR = ""

View File

@ -1,39 +0,0 @@
# configuration for training market1501
dataset:
name: market1501
aug:
resize_size: [384, 128]
random_mirror: True
pad: 10
random_crop: True
random_erasing: True
train:
optimizer: 'Adam'
lr: 0.00035
num_epochs: 80
batch_size: 32
sampler: 'softmax'
wd: 0.0005
step: [30, 55]
factor: 0.1
warmup_epoch: 5
warmup_begin_lr: 0.0000035
loss_fn: 'softmax'
test:
batch_size: 128
network:
name: 'Baseline'
last_stride: 1
gpus: '0'
misc:
eval_step: 20
save_step: 20
log_interval: 100

View File

@ -1,41 +0,0 @@
# configuration for training market1501
dataset:
name: market1501
aug:
resize_size: [384, 128]
random_mirror: True
pad: 10
random_crop: True
random_erasing: True
train:
optimizer: 'Adam'
lr: 0.00035
num_epochs: 400
p_size: 16
k_size: 4
sampler: 'triplet'
wd: 0.0005
step: [80, 180, 300]
factor: 0.1
warmup_epoch: 20
warmup_begin_lr: 0.0000035
loss_fn: 'softmax_triplet'
test:
batch_size: 128
network:
name: 'Baseline'
last_stride: 1
gpus: '1'
misc:
eval_step: 50
save_step: 50
log_interval: 20

View File

@ -1,40 +0,0 @@
# configuration for training market1501
dataset:
name: market1501
aug:
resize_size: [384, 128]
random_mirror: True
pad: 10
random_crop: True
train:
optimizer: 'Adam'
lr: 0.00035
num_epochs: 400
p_size: 32
k_size: 4
sampler: 'triplet'
wd: 0.0005
step: [80, 180, 300]
factor: 0.1
warmup_epoch: 20
warmup_begin_lr: 0.0000035
loss_fn: 'triplet'
test:
batch_size: 128
network:
name: 'Baseline'
last_stride: 1
gpus: '1'
misc:
eval_step: 50
save_step: 50
log_interval: 20

43
configs/softmax.yml Normal file
View File

@ -0,0 +1,43 @@
MODEL:
PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
INPUT:
SIZE_TRAIN: [384, 128]
SIZE_TEST: [384, 128]
PROB: 0.5 # random horizontal flip
PADDING: 10
DATASETS:
NAMES: ('market1501')
DATALOADER:
SAMPLER: 'softmax'
NUM_WORKERS: 8
SOLVER:
OPTIMIZER_NAME: 'Adam'
MAX_EPOCHS: 120
BASE_LR: 0.00035
BIAS_LR_FACTOR: 1
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 64
STEPS: [30, 55]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 5
WARMUP_METHOD: 'linear'
CHECKPOINT_PERIOD: 20
LOG_PERIOD: 100
EVAL_PERIOD: 20
TEST:
IMS_PER_BATCH: 256
OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_bs64_384x128"

View File

@ -0,0 +1,45 @@
MODEL:
PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
INPUT:
SIZE_TRAIN: [384, 128]
SIZE_TEST: [384, 128]
PROB: 0.5 # random horizontal flip
PADDING: 10
DATASETS:
NAMES: ('market1501')
DATALOADER:
SAMPLER: 'softmax_triplet'
NUM_INSTANCE: 4
NUM_WORKERS: 8
SOLVER:
OPTIMIZER_NAME: 'Adam'
MAX_EPOCHS: 120
BASE_LR: 0.00035
BIAS_LR_FACTOR: 1
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 64
STEPS: [40, 70]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 10
WARMUP_METHOD: 'linear'
CHECKPOINT_PERIOD: 40
LOG_PERIOD: 100
EVAL_PERIOD: 40
TEST:
IMS_PER_BATCH: 256
WEIGHT: "path"
OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_triplet_bs128_384x128"

View File

@ -1,11 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

View File

@ -1,79 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import yaml
from easydict import EasyDict as edict
# Module-level default options; `opt` is the public alias of the same object,
# so values merged by update_config are visible through `opt` as well.
__C = edict()
opt = __C
# presumably the random seed -- its consumer is not visible here
__C.seed = 0
# ---- dataset ----
__C.dataset = edict()
# Dataset name handed to init_dataset by get_data_provider
__C.dataset.name = 'market1501'
__C.dataset.num_classes = 751
# ---- data augmentation ----
__C.aug = edict()
# get_data_provider unpacks this as (h, w)
__C.aug.resize_size = [256, 128]
__C.aug.color_jitter = False
__C.aug.random_erasing = False
__C.aug.random_mirror = True
# Padding passed to T.Pad before the random crop; a falsy value disables it
__C.aug.pad = 10
__C.aug.random_crop = True
# ---- training ----
__C.train = edict()
__C.train.optimizer = 'Adam'
__C.train.lr = 3e-4
__C.train.wd = 5e-4
__C.train.momentum = 0.9
__C.train.step = [80, 180, 300]
__C.train.warmup_epoch = 20
__C.train.warmup_begin_lr = 3e-6
__C.train.factor = 0.1
__C.train.margin = 0.3
__C.train.num_epochs = 400
# get_data_provider accepts 'softmax' or 'triplet' and rejects anything else
__C.train.sampler = 'softmax'
__C.train.p_size = 32 # number of person in a single gpu
__C.train.k_size = 4 # number of images per person
# Per-GPU batch size used with the 'softmax' sampler
__C.train.batch_size = 128
__C.train.loss_fn = 'softmax' # softmax, triplet, softmax_triplet
__C.train.triplet_normalize = False
# ---- testing ----
__C.test = edict()
# Per-GPU batch size of the test loader
__C.test.batch_size = 128
__C.test.load_path = '/mnt/truenas/scratch/xingyu.liao/DATA/mx-ckpt'
# ---- network ----
__C.network = edict()
__C.network.depth = 50
# Also used by Solver.fit to build the checkpoint file name
__C.network.name = 'Baseline'
__C.network.last_stride = 1
# GPU ids as a string; get_data_provider derives the number of GPUs from it
__C.network.gpus = "1"
# Passed as num_workers to every DataLoader
__C.network.workers = 8
# ---- misc ----
__C.misc = edict()
# Solver.fit logs every `log_interval` batches (a falsy value disables it)
__C.misc.log_interval = 10
# Solver.fit evaluates every `eval_step` epochs and checkpoints every
# `save_step` epochs into `save_dir`
__C.misc.eval_step = 50
__C.misc.save_step = 50
__C.misc.save_dir = ''
def update_config(config_file):
    """Merge a YAML experiment file into the module-level default options.

    Top-level keys whose value is a mapping are merged one level deep (each
    sub-key overwrites the matching default); any other value replaces the
    default outright.

    Args:
        config_file: Path of the YAML file to read.

    Raises:
        ValueError: If the file contains a top-level key that does not exist
            in the default options.
    """
    with open(config_file) as f:
        # safe_load only builds plain YAML types. The bare yaml.load(f) call
        # can instantiate arbitrary Python objects from the file, and recent
        # PyYAML releases require an explicit Loader argument for it.
        exp_config = edict(yaml.safe_load(f))

    for k, v in exp_config.items():
        if k not in __C:
            raise ValueError("key must exist in configs.py")
        if isinstance(v, dict):
            for vk, vv in v.items():
                __C[k][vk] = vv
        else:
            __C[k] = v

View File

@ -1,126 +0,0 @@
from __future__ import print_function, absolute_import
from collections import defaultdict
import numpy as np
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset, Sampler, DataLoader
from utils import augmenter
from .data_manager import init_dataset
def read_image(img_path):
    """Load ``img_path`` and return it as an RGB PIL image.

    Reading is retried until it succeeds, which avoids IOError incurred by a
    heavy IO process. A path that does not exist can never succeed, so it is
    rejected up front instead of spinning forever.

    Args:
        img_path: Path of the image file.

    Returns:
        The image converted to RGB mode.

    Raises:
        IOError: If ``img_path`` does not exist.
    """
    # Local import so the block does not depend on the module header.
    import os.path as osp

    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert("RGB")
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
class ImageData(Dataset):
    """Dataset wrapper that loads images lazily from (path, pid, camid) records.

    Args:
        dataset: Indexable collection of ``(img_path, pid, camid)`` entries.
        transform: Callable applied to the loaded image, or None to return
            the image untouched.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of records."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(image, pid, camid)`` for the record at position ``item``."""
        path, pid, camid = self.dataset[item]
        image = read_image(path)
        if self.transform is None:
            return image, pid, camid
        return self.transform(image), pid, camid
class RandomIdentitySampler(Sampler):
    """Sampler that yields ``num_instances`` indices for every identity.

    Identities are visited in a random order; for each one,
    ``num_instances`` sample indices are drawn, with replacement only when
    the identity owns fewer samples than requested.

    Args:
        data_source: Sequence of ``(img_path, pid, camid)`` records.
        num_instances: Number of indices drawn per identity.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        # pid -> positions of that identity's samples inside data_source
        self.index_dic = defaultdict(list)
        for position, record in enumerate(data_source):
            _, pid, _ = record
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        sampled = []
        for pid_position in np.random.permutation(self.num_identities):
            candidates = self.index_dic[self.pids[pid_position]]
            # Too few samples for this identity: allow repeated picks.
            with_replacement = len(candidates) < self.num_instances
            picks = np.random.choice(candidates, size=self.num_instances, replace=with_replacement)
            sampled.extend(picks)
        return iter(sampled)

    def __len__(self):
        return self.num_identities * self.num_instances
def get_data_provider(opt):
    """Build the training and test data loaders described by ``opt``.

    Args:
        opt: Option object exposing ``network.gpus`` (comma-separated GPU id
            string), ``network.workers``, ``aug``, ``dataset.name``,
            ``train`` and ``test`` settings.

    Returns:
        Tuple ``(train_loader, test_loader, num_query)``. The test loader
        iterates the query samples followed by the gallery samples, and
        ``num_query`` is the number of query samples.

    Raises:
        ValueError: If ``opt.train.sampler`` is neither 'softmax' nor
            'triplet'.
    """
    # Count the comma-separated ids. The earlier (len(gpus) + 1) // 2 formula
    # only works while every id is a single character ("10,11" gave 3).
    gpu_ids = [gpu for gpu in opt.network.gpus.split(',') if gpu.strip()]
    num_gpus = len(gpu_ids)
    test_batch_size = opt.test.batch_size * num_gpus

    # data augmenter
    random_mirror = opt.aug.get('random_mirror', False)
    pad = opt.aug.get('pad', False)
    random_crop = opt.aug.get('random_crop', False)
    random_erasing = opt.aug.get('random_erasing', False)
    h, w = opt.aug.resize_size

    train_aug = [T.Resize((h, w))]
    if random_mirror:
        train_aug.append(T.RandomHorizontalFlip())
    if pad:
        train_aug.append(T.Pad(padding=pad))
    if random_crop:
        train_aug.append(T.RandomCrop((h, w)))
    train_aug.append(T.ToTensor())
    train_aug.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    if random_erasing:
        # Appended after ToTensor/Normalize, so it receives the tensor.
        train_aug.append(augmenter.RandomErasing())
    train_aug = T.Compose(train_aug)

    test_aug = T.Compose([
        T.Resize((h, w)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = init_dataset(opt.dataset.name)
    train_set = ImageData(dataset.train, train_aug)
    test_set = ImageData(dataset.query + dataset.gallery, test_aug)

    if opt.train.sampler == 'softmax':
        train_batch_size = opt.train.batch_size * num_gpus
        train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True,
                                  num_workers=opt.network.workers, pin_memory=True, drop_last=True)
    elif opt.train.sampler == 'triplet':
        # p_size identities per GPU, k_size images for each identity
        train_batch_size = opt.train.p_size * num_gpus * opt.train.k_size
        train_loader = DataLoader(train_set, batch_size=train_batch_size,
                                  sampler=RandomIdentitySampler(dataset.train, opt.train.k_size),
                                  num_workers=opt.network.workers, pin_memory=True)
    else:
        raise ValueError('sampler must be softmax or triplet, but get {}'.format(opt.train.sampler))

    test_loader = DataLoader(test_set, batch_size=test_batch_size, num_workers=opt.network.workers, pin_memory=True)
    return train_loader, test_loader, len(dataset.query)  # return number of query
if __name__ == "__main__":
    # Manual smoke test: build the loaders from the default options and drop
    # into an interactive IPython shell so they can be inspected by hand.
    from config import opt
    train_loader, test_loader, num_query = get_data_provider(opt)
    from IPython import embed
    embed()

View File

@ -1,187 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import time
import numpy as np
import torch
from utils.meters import AverageMeter
from utils.serialization import save_checkpoint
class Solver(object):
    """Training and evaluation driver for a re-identification network.

    Args:
        opt: Option object; ``train.num_epochs``, ``misc.log_interval``,
            ``misc.eval_step``, ``misc.save_step``, ``misc.save_dir`` and
            ``network.name`` are read.
        net: The network. In training mode it is called as
            ``score, feat = net(data)``; in eval mode its single output is
            used as the feature matrix.
    """

    def __init__(self, opt, net):
        self.opt = opt
        self.net = net
        self.loss = AverageMeter('loss')
        self.acc = AverageMeter('acc')

    def fit(self, train_data, test_data, num_query, optimizer, criterion, lr_scheduler):
        """Run the full training loop with periodic evaluation and checkpoints.

        Args:
            train_data: Loader yielding ``(data, pids, _)`` batches and
                exposing ``batch_size`` and ``len()``.
            test_data: Loader for evaluation, or None to skip evaluation.
            num_query: Number of leading test samples that are queries.
            optimizer: Optimizer whose param groups receive the epoch's
                learning rate.
            criterion: Callable ``criterion(score, feat, label)`` -> loss.
            lr_scheduler: Object whose ``update(epoch)`` returns the learning
                rate for that epoch.
        """
        best_rank1 = -np.inf
        log_interval = self.opt.misc.log_interval
        for epoch in range(self.opt.train.num_epochs):
            self.loss.reset()
            self.acc.reset()
            self.net.train()

            # update learning rate
            lr = lr_scheduler.update(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            logging.info('Epoch [{}] learning rate update to {:.3e}'.format(epoch, lr))

            tic = time.time()
            btic = time.time()
            for i, inputs in enumerate(train_data):
                data, pids, _ = inputs
                # NOTE(review): only the labels are moved to the GPU here;
                # presumably `net` moves/scatters `data` itself -- confirm.
                label = pids.cuda()
                score, feat = self.net(data)
                loss = criterion(score, feat, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.loss.update(loss.item())
                acc = (score.max(1)[1] == label.long()).float().mean().item()
                self.acc.update(acc)

                if log_interval and not (i + 1) % log_interval:
                    loss_name, loss_value = self.loss.get()
                    metric_name, metric_value = self.acc.get()
                    logging.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\t%s=%f',
                        epoch, i + 1, train_data.batch_size * log_interval / (time.time() - btic),
                        loss_name, loss_value,
                        metric_name, metric_value)
                    btic = time.time()

            loss_name, loss_value = self.loss.get()
            metric_name, metric_value = self.acc.get()
            throughput = int(train_data.batch_size * len(train_data) / (time.time() - tic))
            logging.info('[Epoch %d] training: %s=%f\t%s=%f',
                         epoch, loss_name, loss_value, metric_name, metric_value)
            logging.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f',
                         epoch, throughput, time.time() - tic)

            is_best = False
            if test_data is not None and self.opt.misc.eval_step and not (epoch + 1) % self.opt.misc.eval_step:
                rank1 = self.test_func(test_data, num_query)
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1

            if not (epoch + 1) % self.opt.misc.save_step:
                # Unwrap a parallel wrapper exposing `.module`; fall back to
                # the network itself so unwrapped models can be saved too.
                # The state dict is only built when it is actually written.
                model = getattr(self.net, 'module', self.net)
                save_checkpoint({
                    'state_dict': model.state_dict(),
                    'epoch': epoch + 1,
                }, is_best=is_best, save_dir=self.opt.misc.save_dir,
                    filename=self.opt.network.name + '.pth.tar')

    def test_func(self, test_data, num_query):
        """Extract features, rank the gallery for each query and report metrics.

        Args:
            test_data: Loader yielding ``(data, pids, camids)`` batches; the
                first ``num_query`` samples are queries, the rest gallery.
            num_query: Number of query samples.

        Returns:
            The rank-1 CMC score.
        """
        self.net.eval()
        feat, person, camera = list(), list(), list()
        for inputs in test_data:
            data, pids, camids = inputs
            with torch.no_grad():
                outputs = self.net(data).cpu()
            feat.append(outputs)
            person.extend(pids.numpy())
            camera.extend(camids.numpy())
        feat = torch.cat(feat, 0)

        qf = feat[:num_query]
        q_pids = np.asarray(person[:num_query])
        q_camids = np.asarray(camera[:num_query])
        gf = feat[num_query:]
        g_pids = np.asarray(person[num_query:])
        g_camids = np.asarray(camera[num_query:])

        logging.info("Extracted features for query set, obtained {}-by-{} matrix".format(
            qf.shape[0], qf.shape[1]))
        logging.info("Extracted features for gallery set, obtained {}-by-{} matrix".format(
            gf.shape[0], gf.shape[1]))

        logging.info("Computing distance matrix")
        # Squared Euclidean distance: |q|^2 + |g|^2 - 2 * q.g
        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated/removed in newer PyTorch releases.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.numpy()

        logging.info("Computing CMC and mAP")
        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        print("Results ----------")
        print("mAP: {:.1%}".format(mAP))
        print("CMC curve")
        for r in [1, 5, 10]:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")
        return cmc[0]

    @staticmethod
    def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
        """Evaluation with market1501 metric.

        Key: for each query identity, its gallery images from the same camera
        view are discarded.

        Args:
            distmat: (num_query, num_gallery) array of distances.
            q_pids: Array of query identities.
            g_pids: Array of gallery identities.
            q_camids: Array of query camera ids.
            g_camids: Array of gallery camera ids.
            max_rank: Length of the returned CMC curve; lowered to the number
                of gallery samples when the gallery is smaller.

        Returns:
            Tuple ``(all_cmc, mAP)``: the CMC curve averaged over the valid
            queries and the mean average precision.

        Raises:
            AssertionError: If no query identity appears in the gallery.
        """
        num_q, num_g = distmat.shape
        if num_g < max_rank:
            max_rank = num_g
            print("Note: number of gallery samples is quite small, got {}".format(num_g))
        indices = np.argsort(distmat, axis=1)
        matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

        # compute cmc curve for each query
        all_cmc = []
        all_AP = []
        num_valid_q = 0.  # number of valid query
        for q_idx in range(num_q):
            # get query pid and camid
            q_pid = q_pids[q_idx]
            q_camid = q_camids[q_idx]

            # remove gallery samples that have the same pid and camid with query
            order = indices[q_idx]
            remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
            keep = np.invert(remove)

            # binary vector, positions with value 1 are correct matches
            orig_cmc = matches[q_idx][keep]
            if not np.any(orig_cmc):
                # this condition is true when query identity does not appear in gallery
                continue

            # running number of correct matches at each rank
            hits = orig_cmc.cumsum()
            all_cmc.append(np.minimum(hits, 1)[:max_rank])
            num_valid_q += 1.

            # compute average precision
            # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
            num_rel = orig_cmc.sum()
            # precision at every rank, kept only at ranks holding a correct match
            precision = hits / np.arange(1., len(hits) + 1.)
            AP = (precision * orig_cmc).sum() / num_rel
            all_AP.append(AP)

        assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

        all_cmc = np.asarray(all_cmc).astype(np.float32)
        all_cmc = all_cmc.sum(0) / num_valid_q
        mAP = np.mean(all_AP)
        return all_cmc, mAP

7
data/__init__.py Normal file
View File

@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .build import make_data_loader

44
data/build.py Normal file
View File

@ -0,0 +1,44 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch.utils.data import DataLoader
from .collate_batch import train_collate_fn, val_collate_fn
from .datasets import init_dataset, ImageDataset
from .samplers import RandomIdentitySampler
from .transforms import build_transforms
def make_data_loader(cfg):
    """Build the training and validation data loaders described by ``cfg``.

    Args:
        cfg: Configuration node; ``DATALOADER``, ``DATASETS.NAMES``,
            ``SOLVER.IMS_PER_BATCH`` and ``TEST.IMS_PER_BATCH`` are read.

    Returns:
        Tuple ``(train_loader, val_loader, num_query, num_classes)`` where the
        validation loader iterates the query samples followed by the gallery
        samples, ``num_query`` is the number of query samples and
        ``num_classes`` is the number of training identities.
    """
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS

    # The former `len(cfg.DATASETS.NAMES) == 1` test measured the length of a
    # plain string and both of its branches made this same call, so a single
    # call is equivalent.
    # TODO: add multi dataset to train
    dataset = init_dataset(cfg.DATASETS.NAMES)
    num_classes = dataset.num_train_pids

    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        # Any other sampler name draws batches through RandomIdentitySampler.
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, val_loader, len(dataset.query), num_classes

18
data/collate_batch.py Normal file
View File

@ -0,0 +1,18 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
def train_collate_fn(batch):
    """Collate training samples into an image batch and an int64 label tensor.

    Args:
        batch: Sequence of 4-item samples whose first two items are the image
            tensor and the person id; the remaining two items are ignored.

    Returns:
        Tuple ``(images, pids)``: the images stacked along a new first
        dimension and the person ids as a ``torch.int64`` tensor.
    """
    images, person_ids, _camids, _extras = zip(*batch)
    stacked = torch.stack(images, dim=0)
    labels = torch.tensor(person_ids, dtype=torch.int64)
    return stacked, labels
def val_collate_fn(batch):
    """Collate validation samples, keeping person and camera ids as tuples.

    Args:
        batch: Sequence of 4-item samples ``(image, pid, camid, extra)``; the
            last item is ignored.

    Returns:
        Tuple ``(images, pids, camids)``: the images stacked along a new
        first dimension, and the person ids and camera ids as tuples.
    """
    images, person_ids, camera_ids, _extras = zip(*batch)
    stacked = torch.stack(images, dim=0)
    return stacked, person_ids, camera_ids

25
data/datasets/__init__.py Normal file
View File

@ -0,0 +1,25 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .cuhk03 import CUHK03
from .dukemtmcreid import DukeMTMCreID
from .market1501 import Market1501
from .dataset_loader import ImageDataset
__factory = {
'market1501': Market1501,
'cuhk03': CUHK03,
'dukemtmc': DukeMTMCreID
}
def get_names():
    """Return the keys view of the names registered in the dataset factory."""
    registered = __factory.keys()
    return registered
def init_dataset(name, *args, **kwargs):
    """Instantiate the dataset registered under ``name``.

    Args:
        name: Key of the dataset in the factory.
        *args: Positional arguments forwarded to the dataset constructor.
        **kwargs: Keyword arguments forwarded to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        KeyError: If ``name`` is not a registered dataset.
    """
    if name in __factory:
        dataset_cls = __factory[name]
        return dataset_cls(*args, **kwargs)
    raise KeyError("Unknown datasets: {}".format(name))

95
data/datasets/bases.py Normal file
View File

@ -0,0 +1,95 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import numpy as np
class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        """Summarise a list of ``(img_path, pid, camid)`` records.

        Returns:
            Tuple ``(num_pids, num_imgs, num_cams)``: number of distinct
            identities, number of records and number of distinct cameras.
        """
        seen_pids, seen_cams = set(), set()
        for _, pid, camid in data:
            seen_pids.add(pid)
            seen_cams.add(camid)
        return len(seen_pids), len(data), len(seen_cams)

    def get_videodata_info(self, data, return_tracklet_stats=False):
        """Summarise a list of ``(img_paths, pid, camid)`` tracklet records.

        Returns:
            Tuple ``(num_pids, num_tracklets, num_cams)``; when
            ``return_tracklet_stats`` is true a fourth item is appended, the
            list holding the number of images of every tracklet.
        """
        seen_pids, seen_cams = set(), set()
        tracklet_lengths = []
        for img_paths, pid, camid in data:
            seen_pids.add(pid)
            seen_cams.add(camid)
            tracklet_lengths.append(len(img_paths))
        summary = (len(seen_pids), len(data), len(seen_cams))
        if not return_tracklet_stats:
            return summary
        return summary + (tracklet_lengths,)

    def print_dataset_statistics(self):
        """Print dataset statistics; must be provided by subclasses."""
        raise NotImplementedError
class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print a table with the id / image / camera counts of each subset.

        Args:
            train: Training records as ``(img_path, pid, camid)`` tuples.
            query: Query records in the same format.
            gallery: Gallery records in the same format.
        """
        # Gather every summary first so nothing is printed if one of them fails.
        train_stats = self.get_imagedata_info(train)
        query_stats = self.get_imagedata_info(query)
        gallery_stats = self.get_imagedata_info(gallery)

        print("Dataset statistics:")
        print(" ----------------------------------------")
        print(" subset | # ids | # images | # cameras")
        print(" ----------------------------------------")
        print(" train | {:5d} | {:8d} | {:9d}".format(*train_stats))
        print(" query | {:5d} | {:8d} | {:9d}".format(*query_stats))
        print(" gallery | {:5d} | {:8d} | {:9d}".format(*gallery_stats))
        print(" ----------------------------------------")
class BaseVideoDataset(BaseDataset):
    """
    Base class of video reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print the id / tracklet / camera counts and tracklet length range.

        Args:
            train: Training records as ``(img_paths, pid, camid)`` tuples.
            query: Query records in the same format.
            gallery: Gallery records in the same format.
        """
        # Each summary is (num_pids, num_tracklets, num_cams, tracklet_lengths).
        train_info = self.get_videodata_info(train, return_tracklet_stats=True)
        query_info = self.get_videodata_info(query, return_tracklet_stats=True)
        gallery_info = self.get_videodata_info(gallery, return_tracklet_stats=True)

        tracklet_stats = train_info[3] + query_info[3] + gallery_info[3]
        shortest = np.min(tracklet_stats)
        longest = np.max(tracklet_stats)
        average = np.mean(tracklet_stats)

        print("Dataset statistics:")
        print(" -------------------------------------------")
        print(" subset | # ids | # tracklets | # cameras")
        print(" -------------------------------------------")
        print(" train | {:5d} | {:11d} | {:9d}".format(*train_info[:3]))
        print(" query | {:5d} | {:11d} | {:9d}".format(*query_info[:3]))
        print(" gallery | {:5d} | {:11d} | {:9d}".format(*gallery_info[:3]))
        print(" -------------------------------------------")
        print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(shortest, longest, average))
        print(" -------------------------------------------")

259
data/datasets/cuhk03.py Normal file
View File

@ -0,0 +1,259 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import h5py
import os.path as osp
from scipy.io import loadmat
from scipy.misc import imsave
from utils.iotools import mkdir_if_missing, write_json, read_json
from .bases import BaseImageDataset
class CUHK03(BaseImageDataset):
"""
CUHK03
Reference:
Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!
Dataset statistics:
# identities: 1360
# images: 13164
# cameras: 6
# splits: 20 (classic)
Args:
split_id (int): split index (default: 0)
cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
"""
dataset_dir = 'cuhk03'
def __init__(self, root='/export/home/lxy/DATA/reid', split_id=0, cuhk03_labeled=False,
cuhk03_classic_split=False, verbose=True,
**kwargs):
super(CUHK03, self).__init__()
self.dataset_dir = osp.join(root, self.dataset_dir)
self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')
self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')
self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')
self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')
self._check_before_run()
self._preprocess()
if cuhk03_labeled:
image_type = 'labeled'
split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
else:
image_type = 'detected'
split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
splits = read_json(split_path)
assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id,
len(splits))
split = splits[split_id]
print("Split index = {}".format(split_id))
train = split['train']
query = split['query']
gallery = split['gallery']
if verbose:
print("=> CUHK03 ({}) loaded".format(image_type))
self.print_dataset_statistics(train, query, gallery)
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
def _check_before_run(self):
"""Check if all files are available before going deeper"""
if not osp.exists(self.dataset_dir):
raise RuntimeError("'{}' is not available".format(self.dataset_dir))
if not osp.exists(self.data_dir):
raise RuntimeError("'{}' is not available".format(self.data_dir))
if not osp.exists(self.raw_mat_path):
raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
if not osp.exists(self.split_new_det_mat_path):
raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
if not osp.exists(self.split_new_lab_mat_path):
raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))
def _preprocess(self):
    """
    Extract the raw CUHK03 .mat data to png images and write split json files.

    This function is a bit complex and ugly, what it does is
    1. Extract data from cuhk-03.mat and save as png images.
    2. Create 20 classic splits. (Li et al. CVPR'14)
    3. Create new split. (Zhong et al. CVPR'17)

    The whole step is skipped when both image directories and all four
    split json files already exist.
    """
    print(
        "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)")
    # Nothing to do when every output of a previous run is still present.
    if osp.exists(self.imgs_labeled_dir) and \
            osp.exists(self.imgs_detected_dir) and \
            osp.exists(self.split_classic_det_json_path) and \
            osp.exists(self.split_classic_lab_json_path) and \
            osp.exists(self.split_new_det_json_path) and \
            osp.exists(self.split_new_lab_json_path):
        return

    mkdir_if_missing(self.imgs_detected_dir)
    mkdir_if_missing(self.imgs_labeled_dir)

    print("Extract image data from {} and save as png".format(self.raw_mat_path))
    # NOTE(review): the h5py file is opened here and never explicitly closed
    # in this method; the nested helpers below read from it via closure.
    mat = h5py.File(self.raw_mat_path, 'r')

    def _deref(ref):
        # Resolve an HDF5 object reference and return the transposed array.
        return mat[ref][:].T

    def _process_images(img_refs, campid, pid, save_dir):
        # Save every non-empty image of one identity and return the paths.
        img_paths = []  # Note: some persons only have images for one view
        for imgid, img_ref in enumerate(img_refs):
            img = _deref(img_ref)
            # skip empty cell
            if img.size == 0 or img.ndim < 3: continue
            # images are saved with the following format, index-1 (ensure uniqueness)
            # campid: index of camera pair (1-5)
            # pid: index of person in 'campid'-th camera pair
            # viewid: index of view, {1, 2}
            # imgid: index of image, (1-10)
            viewid = 1 if imgid < 5 else 2
            img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
            img_path = osp.join(save_dir, img_name)
            # do not rewrite images that a previous (partial) run already saved
            if not osp.isfile(img_path):
                imsave(img_path, img)
            img_paths.append(img_path)
        return img_paths

    def _extract_img(name):
        # Walk every camera pair of the 'detected' or 'labeled' group and
        # collect (campid, pid, img_paths) records, all 1-based.
        print("Processing {} images (extract and save) ...".format(name))
        meta_data = []
        imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir
        for campid, camp_ref in enumerate(mat[name][0]):
            camp = _deref(camp_ref)
            num_pids = camp.shape[0]
            for pid in range(num_pids):
                img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
                assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid)
                meta_data.append((campid + 1, pid + 1, img_paths))
            print("- done camera pair {} with {} identities".format(campid + 1, num_pids))
        return meta_data

    meta_detected = _extract_img('detected')
    meta_labeled = _extract_img('labeled')

    def _extract_classic_split(meta_data, test_split):
        # Partition identities into train/test by membership of
        # [campid, pid] in test_split; pids are relabelled consecutively
        # (from 0) inside each partition.
        train, test = [], []
        num_train_pids, num_test_pids = 0, 0
        num_train_imgs, num_test_imgs = 0, 0
        for i, (campid, pid, img_paths) in enumerate(meta_data):
            if [campid, pid] in test_split:
                for img_path in img_paths:
                    # third '_'-separated field of the file name is the view id
                    camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                    test.append((img_path, num_test_pids, camid))
                num_test_pids += 1
                num_test_imgs += len(img_paths)
            else:
                for img_path in img_paths:
                    camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                    train.append((img_path, num_train_pids, camid))
                num_train_pids += 1
                num_train_imgs += len(img_paths)
        return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs

    print("Creating classic splits (# = 20) ...")
    splits_classic_det, splits_classic_lab = [], []
    for split_ref in mat['testsets'][0]:
        test_split = _deref(split_ref).tolist()

        # create split for detected images
        train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
            _extract_classic_split(meta_detected, test_split)
        # in the classic protocol query and gallery are the same test list
        splits_classic_det.append({
            'train': train, 'query': test, 'gallery': test,
            'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
            'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
            'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
        })

        # create split for labeled images
        train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
            _extract_classic_split(meta_labeled, test_split)
        splits_classic_lab.append({
            'train': train, 'query': test, 'gallery': test,
            'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
            'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
            'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
        })

    write_json(splits_classic_det, self.split_classic_det_json_path)
    write_json(splits_classic_lab, self.split_classic_lab_json_path)

    def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
        # Build (img_path, pid, camid) records for the given indices and
        # return them with the number of distinct pids and of indices.
        tmp_set = []
        unique_pids = set()
        for idx in idxs:
            img_name = filelist[idx][0]
            camid = int(img_name.split('_')[2]) - 1  # make it 0-based
            pid = pids[idx]
            if relabel: pid = pid2label[pid]
            img_path = osp.join(img_dir, img_name)
            tmp_set.append((img_path, int(pid), camid))
            unique_pids.add(pid)
        return tmp_set, len(unique_pids), len(idxs)

    def _extract_new_split(split_dict, img_dir):
        # Turn one loaded split .mat dict into train/query/gallery info
        # tuples; only the training pids are relabelled.
        train_idxs = split_dict['train_idx'].flatten() - 1  # index-0
        pids = split_dict['labels'].flatten()
        train_pids = set(pids[train_idxs])
        pid2label = {pid: label for label, pid in enumerate(train_pids)}
        query_idxs = split_dict['query_idx'].flatten() - 1
        gallery_idxs = split_dict['gallery_idx'].flatten() - 1
        filelist = split_dict['filelist'].flatten()
        train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
        query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
        gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
        return train_info, query_info, gallery_info

    print("Creating new splits for detected images (767/700) ...")
    train_info, query_info, gallery_info = _extract_new_split(
        loadmat(self.split_new_det_mat_path),
        self.imgs_detected_dir,
    )
    # each *_info tuple is (records, number of pids, number of images)
    splits = [{
        'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
        'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
        'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
        'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
    }]
    write_json(splits, self.split_new_det_json_path)

    print("Creating new splits for labeled images (767/700) ...")
    train_info, query_info, gallery_info = _extract_new_split(
        loadmat(self.split_new_lab_mat_path),
        self.imgs_labeled_dir,
    )
    splits = [{
        'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
        'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
        'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
        'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
    }]
    write_json(splits, self.split_new_lab_json_path)

View File

@ -0,0 +1,45 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset
def read_image(img_path):
    """Load ``img_path`` as an RGB PIL image, retrying until the read works.

    A missing file is reported immediately; an IOError raised while opening
    or converting the image is logged and the read is attempted again
    (this can avoid IOError incurred by heavy IO process).

    Raises:
        IOError: if ``img_path`` does not exist.
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
class ImageDataset(Dataset):
    """Image Person ReID Dataset

    Wraps a sequence of ``(img_path, pid, camid)`` records and loads the
    image lazily on item access.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, pid, camid, img_path)`` for the record at ``index``."""
        img_path, pid, camid = self.dataset[index]
        image = read_image(img_path)
        transform = self.transform
        if transform is None:
            return image, pid, camid, img_path
        return transform(image), pid, camid, img_path

View File

@ -0,0 +1,106 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import glob
import re
import urllib
import zipfile
import os.path as osp
from utils.iotools import mkdir_if_missing
from .bases import BaseImageDataset
class DukeMTMCreID(BaseImageDataset):
    """
    DukeMTMC-reID
    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images:16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8
    """
    dataset_dir = 'dukemtmc-reid'

    def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs):
        """Download (if needed), verify and index the dataset under ``root``.

        Args:
            root: directory that contains (or will contain) ``dukemtmc-reid``.
            verbose: print the dataset statistics table after loading.
            **kwargs: accepted and ignored.
        """
        super(DukeMTMCreID, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')

        self._download_data()
        self._check_before_run()

        # only the training identities are relabelled
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        """Download and unzip the dataset unless ``dataset_dir`` already exists."""
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        # `urllib.urlretrieve` only exists on Python 2; on Python 3 the
        # function lives in `urllib.request`. Imported locally so the fix
        # does not depend on the file-level imports.
        try:
            from urllib.request import urlretrieve
        except ImportError:  # Python 2
            from urllib import urlretrieve

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        # context manager closes the archive even if extraction fails
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.dataset_dir)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        for required in (self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(required):
                raise RuntimeError("'{}' is not available".format(required))

    def _process_dir(self, dir_path, relabel=False):
        """Index every ``*.jpg`` under ``dir_path``.

        Args:
            dir_path: directory holding the images.
            relabel: when true, map the person ids found in this directory
                to consecutive labels starting at 0.

        Returns:
            list of ``(img_path, pid, camid)`` with ``camid`` 0-based.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        # captures "<pid>_c<camera digit>" from the path
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # parse each path once; the result feeds both the label map and the records
        parsed = []
        pid_container = set()
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
            parsed.append((img_path, pid, camid))
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path, pid, camid in parsed:
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, pid, camid))

        return dataset

View File

@ -0,0 +1,63 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric

    Key: for each query identity, its gallery images from the same camera view are discarded.

    Args:
        distmat: 2-D array of query-to-gallery distances (rows are queries).
        q_pids, g_pids: person ids of query / gallery samples.
        q_camids, g_camids: camera ids of query / gallery samples.
        max_rank: length of the returned CMC curve (capped at the gallery size).

    Returns:
        (cmc, mAP): the CMC curve averaged over valid queries, and the mean
        average precision.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    # gallery indices sorted by increasing distance, one row per query
    ranking = np.argsort(distmat, axis=1)
    hits = (g_pids[ranking] == q_pids[:, np.newaxis]).astype(np.int32)

    cmc_rows = []
    ap_values = []
    for q_idx in range(num_q):
        ranked = ranking[q_idx]
        # drop gallery samples sharing both pid and camera with the query
        same_pid = g_pids[ranked] == q_pids[q_idx]
        same_cam = g_camids[ranked] == q_camids[q_idx]
        keep = np.invert(same_pid & same_cam)

        # binary vector, positions with value 1 are correct matches
        relevant = hits[q_idx][keep]
        if not np.any(relevant):
            # query identity does not appear in the (filtered) gallery
            continue

        running_hits = relevant.cumsum()
        cmc_rows.append(np.minimum(running_hits, 1)[:max_rank])

        # average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        precision_at_k = running_hits / np.arange(1., len(running_hits) + 1.)
        ap_values.append((precision_at_k * relevant).sum() / relevant.sum())

    num_valid_q = float(len(cmc_rows))  # number of valid query
    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    mean_cmc = np.asarray(cmc_rows).astype(np.float32).sum(0) / num_valid_q
    return mean_cmc, np.mean(ap_values)

70
core/data_manager.py → data/datasets/market1501.py Executable file → Normal file
View File

@ -1,13 +1,18 @@
from __future__ import print_function, absolute_import
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import glob
import re
from os import path as osp
"""Dataset classes"""
import os.path as osp
from .bases import BaseImageDataset
class Market1501(object):
class Market1501(BaseImageDataset):
"""
Market1501
Reference:
@ -18,9 +23,10 @@ class Market1501(object):
# identities: 1501 (+1 for background)
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
"""
dataset_dir = 'Market-1501-v15.09.15'
dataset_dir = 'market1501'
def __init__(self, root='/home/test2/DATA/market1501/raw/'):
def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs):
super(Market1501, self).__init__()
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
@ -28,31 +34,21 @@ class Market1501(object):
self._check_before_run()
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
num_total_pids = num_train_pids + num_query_pids
num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
train = self._process_dir(self.train_dir, relabel=True)
query = self._process_dir(self.query_dir, relabel=False)
gallery = self._process_dir(self.gallery_dir, relabel=False)
print("=> Market1501 loaded")
print("Dataset statistics:")
print(" ------------------------------")
print(" subset | # ids | # images")
print(" ------------------------------")
print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
print(" ------------------------------")
print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
print(" ------------------------------")
if verbose:
print("=> Market1501 loaded")
self.print_dataset_statistics(train, query, gallery)
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids = num_train_pids
self.num_query_pids = num_query_pids
self.num_gallery_pids = num_gallery_pids
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
def _check_before_run(self):
"""Check if all files are available before going deeper"""
@ -79,31 +75,11 @@ class Market1501(object):
dataset = []
for img_path in img_paths:
pid, camid = map(int, pattern.search(img_path).groups())
if pid == -1:
continue # junk images are just ignored
if pid == -1: continue # junk images are just ignored
assert 0 <= pid <= 1501 # pid == 0 means background
assert 1 <= camid <= 6
camid -= 1 # index starts from 0
if relabel: pid = pid2label[pid]
dataset.append((img_path, pid, camid))
num_pids = len(pid_container)
num_imgs = len(dataset)
return dataset, num_pids, num_imgs
"""Create datasets"""
__factory = {
'market1501': Market1501
}
def get_names():
return __factory.keys()
def init_dataset(name, *args, **kwargs):
if name not in __factory.keys():
raise KeyError("Unknown datasets: {}".format(name))
return __factory[name](*args, **kwargs)
return dataset

View File

@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .triplet_sampler import RandomIdentitySampler

View File

@ -0,0 +1,73 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import copy
import random
from collections import defaultdict
import numpy as np
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (list): list of (img_path, pid, camid).
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): number of examples in a batch.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances

        # pid -> positions of that identity inside data_source
        self.index_dic = defaultdict(list)
        for position, record in enumerate(self.data_source):
            _, pid, _ = record
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())

        # estimate number of examples in an epoch: every identity contributes
        # a whole number of K-sized groups, and at least one group
        self.length = 0
        for pid in self.pids:
            count = max(len(self.index_dic[pid]), self.num_instances)
            self.length += count - count % self.num_instances

    def __iter__(self):
        group_size = self.num_instances

        # split every identity's (shuffled) indices into groups of K;
        # a trailing partial group is dropped
        groups_by_pid = defaultdict(list)
        for pid in self.pids:
            pool = copy.deepcopy(self.index_dic[pid])
            if len(pool) < group_size:
                # too few images: sample with replacement up to K
                pool = np.random.choice(pool, size=group_size, replace=True)
            random.shuffle(pool)
            for start in range(0, len(pool) - group_size + 1, group_size):
                groups_by_pid[pid].append(list(pool[start:start + group_size]))

        remaining_pids = copy.deepcopy(self.pids)
        ordered = []
        # stop once fewer identities remain than one batch needs
        while len(remaining_pids) >= self.num_pids_per_batch:
            for pid in random.sample(remaining_pids, self.num_pids_per_batch):
                ordered.extend(groups_by_pid[pid].pop(0))
                if not groups_by_pid[pid]:
                    remaining_pids.remove(pid)

        return iter(ordered)

    def __len__(self):
        return self.length

View File

@ -0,0 +1,7 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .build import build_transforms

31
data/transforms/build.py Normal file
View File

@ -0,0 +1,31 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu2@jd.com
"""
import torchvision.transforms as T
from .transforms import RandomErasing
def build_transforms(cfg, is_train=True):
    """Compose the image pipeline described by ``cfg.INPUT``.

    Args:
        cfg: config node; the ``INPUT`` fields read are PIXEL_MEAN, PIXEL_STD,
            SIZE_TRAIN, SIZE_TEST, PROB and PADDING.
        is_train: build the augmenting training pipeline when true, the plain
            resize/normalize test pipeline otherwise.

    Returns:
        a ``T.Compose`` instance.
    """
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    if not is_train:
        return T.Compose([
            T.Resize(cfg.INPUT.SIZE_TEST),
            T.ToTensor(),
            normalize_transform
        ])

    # training: resize, flip, pad + random crop back to size, then erase
    train_steps = [
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        normalize_transform,
        RandomErasing(probability=cfg.INPUT.PROB, mean=cfg.INPUT.PIXEL_MEAN),
    ]
    return T.Compose(train_steps)

View File

@ -1,57 +1,12 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
@contact: liaoxingyu2@jd.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import math
import random
from PIL import Image
class Random2DTranslation(object):
"""
With a probability, first increase image size to (1 + 1/8), and then perform random crop.
Args:
height (int): target height.
width (int): target width.
p (float): probability of performing this transformation. Default: 0.5.
"""
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
self.height = height
self.width = width
self.p = p
self.interpolation = interpolation
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be cropped.
Returns:
PIL Image: Cropped image.
"""
if random.random() < self.p:
return img.resize((self.width, self.height), self.interpolation)
new_width, new_height = int(
round(self.width * 1.125)), int(round(self.height * 1.125))
resized_img = img.resize((new_width, new_height), self.interpolation)
x_maxrange = new_width - self.width
y_maxrange = new_height - self.height
x1 = int(round(random.uniform(0, x_maxrange)))
y1 = int(round(random.uniform(0, y_maxrange)))
croped_img = resized_img.crop(
(x1, y1, x1 + self.width, y1 + self.height))
return croped_img
class RandomErasing(object):
""" Randomly selects a rectangle region in an image and erases its pixels.
@ -65,7 +20,7 @@ class RandomErasing(object):
mean: Erasing value.
"""
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
self.probability = probability
self.mean = mean
self.sl = sl

64
engine/inference.py Normal file
View File

@ -0,0 +1,64 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import torch
from ignite.engine import Engine
from utils.reid_metric import R1_mAP
def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches. When not given, batches are
            moved with ``.cuda()`` as before.
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            # honour `device` for the batch too instead of always using CUDA
            data = data.to(device) if device else data.cuda()
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
def inference(
        cfg,
        model,
        val_loader,
        num_query
):
    """Run one evaluation pass over ``val_loader`` and log mAP and CMC.

    Args:
        cfg: config node; ``cfg.MODEL.DEVICE`` selects the device.
        model: the model to evaluate.
        val_loader: loader passed to the evaluator engine.
        num_query: forwarded to the ``R1_mAP`` metric.
    """
    device = cfg.MODEL.DEVICE

    log = logging.getLogger("reid_baseline.inference")
    log.info("Start inferencing")

    metrics = {'r1_mAP': R1_mAP(num_query)}
    evaluator = create_supervised_evaluator(model, metrics=metrics,
                                            device=device)
    evaluator.run(val_loader)

    cmc, mAP = evaluator.state.metrics['r1_mAP']
    log.info('Validation Results')
    log.info("mAP: {:.1%}".format(mAP))
    for rank in (1, 5, 10):
        log.info("CMC curve, Rank-{:<3}:{:.1%}".format(rank, cmc[rank - 1]))

150
engine/trainer.py Normal file
View File

@ -0,0 +1,150 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
from utils.reid_metric import R1_mAP
def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (callable): called as ``loss_fn(score, feat, target)``
        device (str, optional): device type specification (default: None).
            Applies to both model and batches. When not given, batches are
            moved with ``.cuda()`` as before.

    Returns:
        Engine: a trainer engine with supervised update function; each
        iteration outputs ``(loss, accuracy)`` as Python floats.
    """
    if device:
        model.to(device)

    def _to_device(tensor):
        # honour `device` for the batch too instead of always using CUDA
        return tensor.to(device) if device else tensor.cuda()

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        img = _to_device(img)
        target = _to_device(target)
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        loss.backward()
        optimizer.step()
        # compute acc: fraction of samples whose arg-max score equals the target
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches. When not given, batches are
            moved with ``.cuda()`` as before.
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            # honour `device` for the batch too instead of always using CUDA
            data = data.to(device) if device else data.cuda()
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query
):
    """Train ``model`` with an ignite engine, logging, checkpointing and
    periodically evaluating.

    Args:
        cfg: config node; reads SOLVER.LOG_PERIOD, SOLVER.CHECKPOINT_PERIOD,
            SOLVER.EVAL_PERIOD, SOLVER.MAX_EPOCHS, OUTPUT_DIR, MODEL.DEVICE
            and MODEL.NAME.
        model: the model to train.
        train_loader: training loader (its ``batch_size`` and ``len`` are used
            for logging).
        val_loader: loader run by the evaluator every EVAL_PERIOD epochs.
        optimizer: optimizer stepped by the trainer engine.
        scheduler: learning-rate scheduler, stepped at the start of each epoch.
        loss_fn: called as ``loss_fn(score, feat, target)``.
        num_query: forwarded to the ``R1_mAP`` metric.
    """
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # NOTE(review): both state_dict() calls are evaluated once, here at
    # registration time — confirm the saved optimizer entry reflects the
    # state at checkpoint time.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict()})
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch
        batch_idx = (engine.state.iteration - 1) % len(train_loader) + 1

        if batch_idx % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, batch_idx, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        # The timer averages over steps, so value() already is the mean time
        # per batch; multiplying by step_count reported the epoch total
        # under the "Time per batch" label.
        seconds_per_batch = timer.value()
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, seconds_per_batch,
                            train_loader.batch_size / seconds_per_batch))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)

28
layers/__init__.py Normal file
View File

@ -0,0 +1,28 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from .triplet_loss import TripletLoss
def make_loss(cfg):
    """Build the training loss selected by ``cfg.DATALOADER.SAMPLER``.

    Args:
        cfg: config node; reads ``DATALOADER.SAMPLER`` and, for the triplet
            variants, ``SOLVER.MARGIN``.

    Returns:
        callable ``loss_func(score, feat, target)``:
        - 'softmax': cross entropy on ``score``
        - 'triplet': triplet loss on ``feat``
        - 'softmax_triplet': sum of both

    Raises:
        ValueError: if the sampler name is not one of the three above
            (previously this printed a message and then failed on an unbound
            ``loss_func``).
    """
    sampler = cfg.DATALOADER.SAMPLER

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)
        return loss_func

    if sampler not in ('triplet', 'softmax_triplet'):
        raise ValueError('expected sampler should be softmax, triplet or softmax_triplet, '
                         'but got {}'.format(sampler))

    # only the triplet variants need the triplet loss object
    triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss

    if sampler == 'triplet':
        def loss_func(score, feat, target):
            # TripletLoss returns a tuple whose first element is the loss
            return triplet(feat, target)[0]
    else:
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target) + triplet(feat, target)[0]
    return loss_func

View File

@ -1,17 +1,10 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch import nn
import torch.nn.functional as F
def normalize(x, axis=-1):
@ -121,34 +114,3 @@ class TripletLoss(object):
else:
loss = self.ranking_loss(dist_an - dist_ap, y)
return loss, dist_ap, dist_an
class CrossEntropyLabelSmooth(nn.Module):
"""Cross entropy loss with label smoothing regularizer.
Reference:
Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
Equation: y = (1 - epsilon) * y + epsilon / K.
Args:
num_classes (int): number of classes.
epsilon (float): weight.
"""
def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.use_gpu = use_gpu
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
"""
Args:
inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
targets: ground truth labels with shape (num_classes)
"""
log_probs = self.logsoftmax(inputs)
targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
if self.use_gpu: targets = targets.cuda()
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (- targets * log_probs).mean(0).sum()
return loss

13
modeling/__init__.py Normal file
View File

@ -0,0 +1,13 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .baseline import Baseline
def build_model(cfg, num_classes):
    """Instantiate the model named by ``cfg.MODEL.NAME``.

    Args:
        cfg: config node; reads ``MODEL.NAME``, ``MODEL.LAST_STRIDE`` and
            ``MODEL.PRETRAIN_PATH``.
        num_classes: number of identity classes passed to the model.

    Returns:
        a ``Baseline`` model.

    Raises:
        ValueError: if ``cfg.MODEL.NAME`` is not 'resnet50' (previously this
            failed on an unbound ``model``).
    """
    if cfg.MODEL.NAME != 'resnet50':
        raise ValueError("Unknown model name: {} (only 'resnet50' is supported)".format(cfg.MODEL.NAME))
    return Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH)

View File

@ -0,0 +1,6 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

View File

@ -1,17 +1,12 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: liaoxingyu@megvii.com
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import math
import torch as th
import torch
from torch import nn
@ -98,7 +93,7 @@ class ResNet(nn.Module):
return x
def load_param(self, model_path):
param_dict = th.load(model_path)
param_dict = torch.load(model_path)
for i in param_dict:
if 'fc' in i:
continue
@ -112,11 +107,3 @@ class ResNet(nn.Module):
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
if __name__ == "__main__":
net = ResNet(last_stride=2)
import torch
x = net(torch.zeros(1, 3, 256, 128))
print(x.shape)

View File

@ -1,17 +1,12 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from torch import nn
from .resnet import ResNet
from .backbones.resnet import ResNet
def weights_init_kaiming(m):
@ -40,11 +35,12 @@ def weights_init_classifier(m):
class Baseline(nn.Module):
in_planes = 2048
def __init__(self, num_classes=10, last_stride=1, model_path='/home/test2/.torch/models/resnet50-19c8e357.pth'):
def __init__(self, num_classes, last_stride, model_path):
super(Baseline, self).__init__()
self.base = ResNet(last_stride)
self.base.load_param(model_path)
self.gap = nn.AdaptiveAvgPool2d(1)
# self.gap = nn.AdaptiveMaxPool2d(1)
self.num_classes = num_classes
self.bottleneck = nn.BatchNorm1d(self.in_planes)
@ -63,15 +59,3 @@ class Baseline(nn.Module):
return cls_score, global_feat # global feature for triplet loss
else:
return feat
if __name__ == '__main__':
# net = Baseline(751).cuda(1)
import torch
net = ResNet(1).cuda(1)
x = torch.ones(128, 3, 256, 128).cuda(1)
y = net(x)
from IPython import embed
embed()

View File

@ -1,13 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from .baseline import Baseline

View File

@ -1,5 +0,0 @@
#!/usr/bin/env bash
python3 tools/test.py --config_file='configs/market_softmax_triplet.yml' \
--load_model='/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax_triplet/350_Baseline350.pth.tar'

View File

@ -1,8 +0,0 @@
#!/usr/bin/env bash
checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax/
mkdir -p ${checkpoint_dir}
python3 tools/train.py --config_file='configs/market_softmax.yml' \
--save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log

View File

@ -1,8 +0,0 @@
#!/usr/bin/env bash
checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax_triplet/
mkdir -p ${checkpoint_dir}
python3 tools/train.py --config_file='configs/market_softmax_triplet.yml' \
--save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log

View File

@ -1,8 +0,0 @@
#!/usr/bin/env bash
# Train with the triplet config and mirror the console output into
# <checkpoint_dir>/train.log.
#
# -e/-u: stop on the first failing command or unset variable.
# pipefail: without it the pipeline's exit status is tee's, so a crashed
# training run would still look successful to the caller.
set -euo pipefail

checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_triplet/
# Quote every expansion so a path containing spaces is not word-split.
mkdir -p "${checkpoint_dir}"
python3 tools/train.py --config_file='configs/market_triplet.yml' \
    --save_dir="${checkpoint_dir}" | tee "${checkpoint_dir}/train.log"

8
solver/__init__.py Normal file
View File

@ -0,0 +1,8 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from .build import make_optimizer
from .lr_scheduler import WarmupMultiStepLR

25
solver/build.py Normal file
View File

@ -0,0 +1,25 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import torch
def make_optimizer(cfg, model):
    """Build a ``torch.optim`` optimizer with one parameter group per tensor.

    Every trainable parameter of ``model`` gets its own group so that the
    learning rate and weight decay can differ between parameters whose name
    contains ``"bias"`` and all the others.

    Args:
        cfg: config object exposing ``SOLVER.BASE_LR``, ``SOLVER.WEIGHT_DECAY``,
            ``SOLVER.BIAS_LR_FACTOR``, ``SOLVER.WEIGHT_DECAY_BIAS``,
            ``SOLVER.OPTIMIZER_NAME`` and (for SGD) ``SOLVER.MOMENTUM``.
        model: object providing ``named_parameters()``.

    Returns:
        An instance of ``torch.optim.<OPTIMIZER_NAME>``.
    """
    solver = cfg.SOLVER
    param_groups = []
    for name, param in model.named_parameters():
        # Frozen parameters are not handed to the optimizer at all.
        if not param.requires_grad:
            continue
        if "bias" in name:
            group_lr = solver.BASE_LR * solver.BIAS_LR_FACTOR
            group_wd = solver.WEIGHT_DECAY_BIAS
        else:
            group_lr = solver.BASE_LR
            group_wd = solver.WEIGHT_DECAY
        param_groups.append(
            {"params": [param], "lr": group_lr, "weight_decay": group_wd}
        )

    optimizer_cls = getattr(torch.optim, solver.OPTIMIZER_NAME)
    # Only the SGD branch forwards a momentum value.
    if solver.OPTIMIZER_NAME == 'SGD':
        return optimizer_cls(param_groups, momentum=solver.MOMENTUM)
    return optimizer_cls(param_groups)

56
solver/lr_scheduler.py Normal file
View File

@ -0,0 +1,56 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from bisect import bisect_right
import torch
# FIXME ideally this would be achieved with a CombinedLRScheduler,
# separating MultiStepLR with WarmupLR
# but the current LRScheduler design doesn't allow it
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step learning-rate decay preceded by a warmup phase.

    While ``last_epoch < warmup_iters`` the base learning rate is scaled by a
    warmup factor; independently of that, it is multiplied by ``gamma`` once
    for every milestone that ``last_epoch`` has reached.

    Args:
        optimizer: wrapped optimizer, forwarded to the base scheduler.
        milestones: non-decreasing sequence of step indices at which the
            learning rate is multiplied by ``gamma``.
        gamma: multiplicative decay applied per reached milestone.
        warmup_factor: scale used at the start of warmup.
        warmup_iters: number of scheduler steps the warmup lasts (compared
            against ``last_epoch``).
        warmup_method: ``"constant"`` keeps the scale at ``warmup_factor``
            during warmup; ``"linear"`` interpolates it from ``warmup_factor``
            up to 1.
        last_epoch: forwarded to the base scheduler.

    Raises:
        ValueError: if ``milestones`` is not sorted or ``warmup_method`` is
            not one of the two supported values.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # Format the message; passing milestones as a second exception
            # argument would leave the "{}" placeholder unfilled.
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )

        # These must be set before the base constructor runs, since it is
        # handed the optimizer and may query get_lr().
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per entry of ``self.base_lrs``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # alpha runs from 0 towards 1 over the warmup, so the scale
                # moves linearly from warmup_factor to 1. float() keeps this
                # a true division even under integer-division semantics.
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha

        # bisect_right counts the milestones that are <= last_epoch.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]

5
tests/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

View File

@ -0,0 +1,26 @@
import sys
import unittest
import torch
from torch import nn
sys.path.append('.')
from solver.lr_scheduler import WarmupMultiStepLR
from solver.build import make_optimizer
from config import cfg
# Smoke test for the solver package: builds an optimizer and a
# WarmupMultiStepLR scheduler and prints the learning-rate trajectory.
# It contains no assertions, so it only fails if something raises.
class MyTestCase(unittest.TestCase):
def test_something(self):
# A tiny model is enough; only its parameters are needed.
net = nn.Linear(10, 10)
optimizer = make_optimizer(cfg, net)
# Decay milestones at 20 and 40, warmup over the first 10 scheduler steps.
lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10)
for i in range(50):
lr_scheduler.step()
# NOTE(review): indentation is lost in this view, so whether
# optimizer.step() sits inside the inner loop cannot be determined here.
for j in range(3):
# Print the step index and the first entry returned by get_lr().
print(i, lr_scheduler.get_lr()[0])
optimizer.step()
if __name__ == '__main__':
unittest.main()

View File

@ -3,9 +3,3 @@
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

View File

@ -4,64 +4,61 @@
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import logging
import os
import sys
from pprint import pprint
from os import mkdir
import torch
from torch import nn
from torch.backends import cudnn
import network
from core.config import opt, update_config
from core.loader import get_data_provider
from core.solver import Solver
FORMAT = '[%(levelname)s]: %(message)s'
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
stream=sys.stdout
)
def test(args):
logging.info('======= user config ======')
logging.info(pprint(opt))
logging.info(pprint(args))
logging.info('======= end ======')
train_data, test_data, num_query = get_data_provider(opt)
net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride)
net.load_state_dict(torch.load(args.load_model)['state_dict'])
net = nn.DataParallel(net).cuda()
mod = Solver(opt, net)
mod.test_func(test_data, num_query)
sys.path.append('.')
from config import cfg
from data import make_data_loader
from engine.inference import inference
from modeling import build_model
from utils.logger import setup_logger
def main():
parser = argparse.ArgumentParser(description='reid model testing')
parser.add_argument('--config_file', type=str, default=None,
help='Optional config file for params')
parser.add_argument('--load_model', type=str, required=True,
help='load trained model for testing')
parser = argparse.ArgumentParser(description="ReID Baseline Inference")
parser.add_argument(
"--config_file", default="", help="path to config file", type=str
)
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
if args.config_file is not None:
update_config(args.config_file)
os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
if args.config_file != "":
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
output_dir = cfg.OUTPUT_DIR
if output_dir and not os.path.exists(output_dir):
mkdir(output_dir)
logger = setup_logger("reid_baseline", output_dir, 0)
logger.info("Using {} GPUS".format(num_gpus))
logger.info(args)
if args.config_file != "":
logger.info("Loaded configuration file {}".format(args.config_file))
with open(args.config_file, 'r') as cf:
config_str = "\n" + cf.read()
logger.info(config_str)
logger.info("Running with config:\n{}".format(cfg))
cudnn.benchmark = True
test(args)
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
model = build_model(cfg, num_classes)
model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
inference(cfg, model, val_loader, num_query)
if __name__ == '__main__':

View File

@ -4,95 +4,83 @@
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import logging
import os
import sys
from pprint import pprint
import torch
from torch import nn
from torch.backends import cudnn
import network
from core.config import opt, update_config
from core.loader import get_data_provider
from core.solver import Solver
from utils.loss import TripletLoss
from utils.lr_scheduler import LRScheduler
sys.path.append('.')
from config import cfg
from data import make_data_loader
from engine.trainer import do_train
from modeling import build_model
from layers import make_loss
from solver import make_optimizer, WarmupMultiStepLR
FORMAT = '[%(levelname)s]: %(message)s'
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
stream=sys.stdout
)
from utils.logger import setup_logger
def train(args):
logging.info('======= user config ======')
logging.info(pprint(opt))
logging.info(pprint(args))
logging.info('======= end ======')
def train(cfg):
# prepare dataset
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
# prepare model
model = build_model(cfg, num_classes)
train_data, test_data, num_query = get_data_provider(opt)
optimizer = make_optimizer(cfg, model)
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride)
loss_func = make_loss(cfg)
optimizer = getattr(torch.optim, opt.train.optimizer)(net.parameters(), lr=opt.train.lr, weight_decay=opt.train.wd)
ce_loss = nn.CrossEntropyLoss()
triplet_loss = TripletLoss(margin=opt.train.margin)
arguments = {}
def ce_loss_func(scores, feat, labels):
ce = ce_loss(scores, labels)
return ce
def tri_loss_func(scores, feat, labels):
tri = triplet_loss(feat, labels)[0]
return tri
def ce_tri_loss_func(scores, feat, labels):
ce = ce_loss(scores, labels)
triplet = triplet_loss(feat, labels)[0]
return ce + triplet
if opt.train.loss_fn == 'softmax':
loss_fn = ce_loss_func
elif opt.train.loss_fn == 'triplet':
loss_fn = tri_loss_func
elif opt.train.loss_fn == 'softmax_triplet':
loss_fn = ce_tri_loss_func
else:
raise ValueError('Unknown loss func {}'.format(opt.train.loss_fn))
lr_scheduler = LRScheduler(base_lr=opt.train.lr, step=opt.train.step,
factor=opt.train.factor, warmup_epoch=opt.train.warmup_epoch,
warmup_begin_lr=opt.train.warmup_begin_lr)
net = nn.DataParallel(net).cuda()
mod = Solver(opt, net)
mod.fit(train_data=train_data, test_data=test_data, num_query=num_query, optimizer=optimizer,
criterion=loss_fn, lr_scheduler=lr_scheduler)
do_train(
cfg,
model,
train_loader,
val_loader,
optimizer,
scheduler,
loss_func,
num_query
)
def main():
parser = argparse.ArgumentParser(description='reid model training')
parser.add_argument('--config_file', type=str, default=None, required=True,
help='Optional config file for params')
parser.add_argument('--save_dir', type=str, default=None, required=True,
help='model save checkpoint directory')
parser = argparse.ArgumentParser(description="ReID Baseline Training")
parser.add_argument(
"--config_file", default="", help="path to config file", type=str
)
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
if args.config_file is not None:
update_config(args.config_file)
opt.misc.save_dir = args.save_dir
os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
if args.config_file != "":
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
output_dir = cfg.OUTPUT_DIR
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
logger = setup_logger("reid_baseline", output_dir, 0)
logger.info("Using {} GPUS".format(num_gpus))
logger.info(args)
if args.config_file != "":
logger.info("Loaded configuration file {}".format(args.config_file))
with open(args.config_file, 'r') as cf:
config_str = "\n" + cf.read()
logger.info(config_str)
logger.info("Running with config:\n{}".format(cfg))
cudnn.benchmark = True
train(args)
train(cfg)
if __name__ == '__main__':

View File

@ -1,11 +1,6 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

39
utils/iotools.py Normal file
View File

@ -0,0 +1,39 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import errno
import json
import os
import os.path as osp
def mkdir_if_missing(directory):
    """Create ``directory`` (with missing parents) unless the path exists.

    Raises:
        OSError: for any ``os.makedirs`` failure other than ``EEXIST``.
    """
    if osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as err:
        # EEXIST here means the path appeared after the check above;
        # that is tolerated, every other error is propagated.
        if err.errno == errno.EEXIST:
            return
        raise
def check_isfile(path):
    """Return whether ``path`` is an existing regular file.

    When it is not, a warning line is printed to stdout and ``False`` is
    returned; nothing is raised.
    """
    if osp.isfile(path):
        return True
    print("=> Warning: no file found at '{}' (ignored)".format(path))
    return False
def read_json(fpath):
    """Open ``fpath`` in text mode and return its parsed JSON content."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
def write_json(obj, fpath):
    """Serialize ``obj`` as indented JSON into ``fpath``.

    The parent directory of ``fpath`` is created first when it is missing.

    Args:
        obj: JSON-serializable object.
        fpath: destination file path; an existing file is overwritten.
    """
    parent = osp.dirname(fpath)
    # A bare filename has an empty dirname; os.makedirs('') would raise,
    # and the current directory needs no creation anyway.
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))

30
utils/logger.py Normal file
View File

@ -0,0 +1,30 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import sys
def setup_logger(name, save_dir, distributed_rank):
    """Return the logger ``name`` set to DEBUG, with handlers on rank 0 only.

    Args:
        name: logger name passed to ``logging.getLogger``.
        save_dir: if truthy, ``<save_dir>/log.txt`` is opened in write mode
            and receives the same records as stdout.
        distributed_rank: any rank greater than 0 gets the logger back
            without handlers attached by this function.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # Non-master processes: no handlers are attached here.
    if distributed_rank > 0:
        return logger

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    def _attach(handler):
        # Same level and format for every handler.
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    _attach(logging.StreamHandler(stream=sys.stdout))
    if save_dir:
        _attach(logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w'))
    return logger

View File

@ -1,65 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class LRScheduler(object):
    """Step-decay learning-rate schedule with an optional warmup phase.

    ``update(num_epoch)`` returns the learning rate for that epoch.

    Parameters
    ----------
    base_lr : float, optional
        Learning rate used once warmup is over and before any decay.
    step : sequence of int
        Epochs at which the learning rate is multiplied by ``factor``.
    factor : float
        Multiplicative decay applied once per reached entry of ``step``.
    warmup_epoch : int
        Number of epochs the warmup lasts; 0 disables it.
    warmup_begin_lr : float
        Learning rate at the start of warmup; must not exceed ``base_lr``.
    warmup_mode : string
        'linear' raises the rate in equal increments from
        ``warmup_begin_lr`` towards ``base_lr``;
        'constant' keeps it at ``warmup_begin_lr`` during warmup.
    """

    def __init__(self, base_lr=0.01, step=(30, 60), factor=0.1,
                 warmup_epoch=0, warmup_begin_lr=0, warmup_mode='linear'):
        # Validation order is significant: it decides which error is raised
        # when several arguments are invalid at once.
        assert isinstance(warmup_epoch, int)
        if warmup_begin_lr > base_lr:
            raise ValueError("Base lr has to be higher than warmup_begin_lr")
        if warmup_epoch < 0:
            raise ValueError("Warmup steps has to be positive or 0")
        if warmup_mode not in ['linear', 'constant']:
            raise ValueError("Supports only linear and constant modes of warmup")

        self.base_lr = base_lr
        self.learning_rate = base_lr
        self.step = step
        self.factor = factor
        self.warmup_epoch = warmup_epoch
        self.warmup_final_lr = base_lr
        self.warmup_begin_lr = warmup_begin_lr
        self.warmup_mode = warmup_mode

    def update(self, num_epoch):
        """Store and return the learning rate for ``num_epoch``."""
        if num_epoch < self.warmup_epoch:
            if self.warmup_mode == 'linear':
                span = self.warmup_final_lr - self.warmup_begin_lr
                self.learning_rate = self.warmup_begin_lr + span * num_epoch / self.warmup_epoch
            elif self.warmup_mode == 'constant':
                self.learning_rate = self.warmup_begin_lr
        else:
            # Number of decay steps that have already been reached.
            reached = 0
            for boundary in self.step:
                if boundary <= num_epoch:
                    reached += 1
            self.learning_rate = self.base_lr * pow(self.factor, reached)
        return self.learning_rate

View File

@ -1,54 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import math
import numpy as np
class AverageMeter(object):
    """Running mean and sample standard deviation of the values fed to it.

    Attributes are ``name``, ``n`` (accumulated count), ``sum``, ``var``
    (running sum of squares), ``val`` (last value), ``mean`` and ``std``.
    """

    def __init__(self, name):
        self.name = name
        # All statistics start in the same state that reset() restores.
        self.reset()

    def update(self, value, n=1):
        """Record ``value`` and refresh ``mean`` and ``std``.

        NOTE(review): ``value`` is added to the running sums once while the
        count grows by ``n``, so with ``n != 1`` the mean is ``sum / n``
        rather than an n-weighted average. Confirm this is intended.
        """
        self.val = value
        self.sum += value
        self.var += value * value
        self.n += n

        if self.n == 0:
            self.mean, self.std = np.nan, np.nan
        elif self.n == 1:
            # A single sample has no spread estimate.
            self.mean, self.std = self.sum, np.inf
        else:
            self.mean = self.sum / self.n
            variance = (self.var - self.n * self.mean * self.mean) / (self.n - 1.0)
            # Floating-point cancellation can push the variance of
            # near-constant data slightly below zero; clamp it so that
            # math.sqrt does not raise ValueError.
            self.std = math.sqrt(max(variance, 0.0))

    def value(self):
        """Return ``(mean, std)``."""
        return self.mean, self.std

    def get(self):
        """Return ``(name, mean)``."""
        return self.name, self.mean

    def reset(self):
        """Clear all accumulated statistics; ``name`` is kept."""
        self.n = 0
        self.sum = 0.0
        self.var = 0.0
        self.val = 0.0
        self.mean = np.nan
        self.std = np.nan

48
utils/reid_metric.py Normal file
View File

@ -0,0 +1,48 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import numpy as np
import torch
from ignite.metrics import Metric
from data.datasets.eval_reid import eval_func
class R1_mAP(Metric):
    """Metric that accumulates features and evaluates CMC / mAP.

    Features, person ids and camera ids are collected batch by batch. The
    first ``num_query`` accumulated samples are treated as queries and the
    rest as the gallery.

    Args:
        num_query: number of leading samples that form the query set.
        max_rank: stored on the instance; not used by ``compute`` as written.
    """

    def __init__(self, num_query, max_rank=50):
        super(R1_mAP, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank

    def reset(self):
        """Drop everything accumulated so far."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Append one batch given as a ``(feat, pid, camid)`` triple."""
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):
        """Return ``(cmc, mAP)`` as produced by ``eval_func``."""
        feats = torch.cat(self.feats, dim=0)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])

        # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 * q . g
        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated and rejected by newer PyTorch releases.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.cpu().numpy()
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        return cmc, mAP

View File

@ -1,35 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: xyliao1993@qq.com
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import errno
import os
import shutil
import sys
import os.path as osp
import torch
def mkdir_if_missing(dir_path):
    """Create ``dir_path`` (with missing parents); an existing path is fine.

    Raises:
        OSError: for any ``os.makedirs`` failure other than ``EEXIST``.
    """
    try:
        os.makedirs(dir_path)
    except OSError as err:
        # Only "already exists" is tolerated; everything else propagates.
        if err.errno == errno.EEXIST:
            return
        raise
def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    """Save ``state`` under ``save_dir`` with the epoch prefixed to the name.

    The file is written as ``<save_dir>/<state['epoch']>_<filename>``; the
    directory is created when missing. If ``is_best`` is truthy the file is
    also copied to ``<save_dir>/model_best.pth.tar``.

    Args:
        state: object accepted by ``torch.save`` that has an ``'epoch'`` key.
        is_best: whether to refresh the ``model_best.pth.tar`` copy.
        save_dir: target directory.
        filename: base name of the checkpoint file.
    """
    epoch_tagged = str(state['epoch']) + '_' + filename
    ckpt_path = osp.join(save_dir, epoch_tagged)
    mkdir_if_missing(save_dir)
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copy(ckpt_path, osp.join(save_dir, 'model_best.pth.tar'))