mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
Update first stable version v1.0
This commit is contained in:
parent
519bac01fc
commit
69e12d989d
58
README.md
58
README.md
@ -1,21 +1,23 @@
|
||||
# ReID_baseline
|
||||
Baseline model (with bottleneck) for person ReID (using softmax and triplet loss). This is PyTorch version, [mxnet version](https://github.com/L1aoXingyu/reid_baseline_gluon) has a better result and more SOTA methods.
|
||||
Baseline model (with bottleneck) for person ReID (using softmax and triplet loss).
|
||||
|
||||
We support
|
||||
- multi-GPU training
|
||||
- easy dataset preparation
|
||||
- end-to-end training and evaluation
|
||||
- [x] easy dataset preparation
|
||||
- [x] end-to-end training and evaluation
|
||||
- [x] highly modular management
|
||||
|
||||
## Get Started
|
||||
The architecture follows the [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template) guide; refer to it for the purpose of each folder.
|
||||
|
||||
1. `cd` to folder where you want to download this repo
|
||||
2. Run `git clone https://github.com/L1aoXingyu/reid_baseline.git`
|
||||
3. Install dependencies:
|
||||
- [pytorch 0.4](https://pytorch.org/)
|
||||
- [pytorch 1.0](https://pytorch.org/)
|
||||
- torchvision
|
||||
- tensorflow (for tensorboard)
|
||||
- [tensorboardX](https://github.com/lanpa/tensorboardX)
|
||||
- [ignite](https://github.com/pytorch/ignite)
|
||||
- [yacs](https://github.com/rbgirshick/yacs)
|
||||
4. Prepare dataset
|
||||
|
||||
|
||||
Create a directory to store reid datasets under this repo via
|
||||
```bash
|
||||
cd reid_baseline
|
||||
@ -23,39 +25,43 @@ We support
|
||||
```
|
||||
1. Download dataset to `data/` from http://www.liangzheng.org/Project/project_reid.html
|
||||
2. Extract the dataset and rename it to `market1501`. The directory structure should look like:
|
||||
```
|
||||
market1501/
|
||||
bounding_box_test/
|
||||
bounding_box_train/
|
||||
```bash
|
||||
data
|
||||
market1501
|
||||
bounding_box_test/
|
||||
bounding_box_train/
|
||||
```
|
||||
5. Prepare a pretrained model if you don't have one
|
||||
```python
|
||||
from torchvision import models
|
||||
models.resnet50(pretrained=True)
|
||||
```
|
||||
Then it will automatically download model in `~.torch/models/`, you should set this path in `config.py`
|
||||
This automatically downloads the model to `~/.torch/models/`. Set this path in `config/defaults.py` to use it for all trainings, or set it in each individual training config file in `configs/`.
|
||||
|
||||
## Train
|
||||
You can run
|
||||
We provide most of the configuration files; you can run the following command to train with one of them
|
||||
```bash
|
||||
bash scripts/train_triplet_softmax.sh
|
||||
python3 tools/train.py --config_file='configs/market1501_softmax_bs64.yml'
|
||||
```
|
||||
|
||||
You can also override cfg parameters on the command line as follows
|
||||
```bash
|
||||
python3 tools/train.py --config_file='configs/market1501_softmax_bs64.yml' INPUT.SIZE_TRAIN '(256, 128)' INPUT.SIZE_TEST '(256, 128)'
|
||||
```
|
||||
in `reid_baseline` folder if you want to train with softmax and triplet loss. You can find others train scripts in `scripts`.
|
||||
|
||||
## Results
|
||||
|
||||
**network architecture**
|
||||
|
||||
<div align=center>
|
||||
<img src='https://ws3.sinaimg.cn/large/006tNbRwly1fvh3ekjh12j315k0j4q58.jpg' width='500'>
|
||||
</div>
|
||||
|
||||
| cfg | market1501 | cuhk03 | dukemtmc |
|
||||
| --- | -- | -- | -- |
|
||||
| softmax, size=(384, 128), batch_size=64 | 92.5 (79.4) | 60.4 (56.1) | 84.6 (68.1) |
|
||||
| softmax, size=(256, 128), batch_size=64 | 92.0 (80.4) | 60.5 (55.5) | 84.1(68.4) |
|
||||
| softmax_triplet, size=(384, 128), batch_size=128(32 id x 4 imgs) | 93.2 (82.5) | - | 86.4 (73.1)
|
||||
| softmax_triplet, size=(256, 128), batch_size=128(32 id x 4 imgs) | 93.8 (83.2) | 65.9 (61.4) | -
|
||||
|
||||
| config | Market1501 |
|
||||
| --- | -- |
|
||||
| bs(32) size(384,128) softmax | 92.2 (78.5) |
|
||||
| bs(64) size(384,128) softmax | 92.5 (79.6) |
|
||||
| bs(32) size(256,128) softmax | 92.0 (78.4) |
|
||||
| bs(64) size(256,128) softmax | 91.7 (78.3) |
|
||||
| bs(128) size(256,128) softmax | 91.2 (77.4) |
|
||||
| triplet(p=32,k=4) size(256,128) | 88.3 (73.8) |
|
||||
| triplet(p=16,k=4)+softmax size(384,128) | 93.1 (82.0) |
|
||||
| triplet(p=24,k=4)+softmax size(384,128) | 91.7 (79.0) |
|
||||
|
||||
|
7
config/__init__.py
Normal file
7
config/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .defaults import _C as cfg
|
101
config/defaults.py
Normal file
101
config/defaults.py
Normal file
@ -0,0 +1,101 @@
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Convention about Training / Test specific parameters
|
||||
# -----------------------------------------------------------------------------
|
||||
# Whenever an argument can be either used for training or for testing, the
|
||||
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
|
||||
# or _TEST for a test-specific parameter.
|
||||
# For example, the number of images during training will be
|
||||
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
|
||||
# IMAGES_PER_BATCH_TEST
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Config definition
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
_C = CN()
|
||||
|
||||
_C.MODEL = CN()
|
||||
_C.MODEL.DEVICE = "cuda"
|
||||
_C.MODEL.NAME = 'resnet50'
|
||||
_C.MODEL.LAST_STRIDE = 1
|
||||
_C.MODEL.PRETRAIN_PATH = ''
|
||||
# -----------------------------------------------------------------------------
|
||||
# INPUT
|
||||
# -----------------------------------------------------------------------------
|
||||
_C.INPUT = CN()
|
||||
# Size of the image during training
|
||||
_C.INPUT.SIZE_TRAIN = [384, 128]
|
||||
# Size of the image during test
|
||||
_C.INPUT.SIZE_TEST = [384, 128]
|
||||
# Random probability for image horizontal flip
|
||||
_C.INPUT.PROB = 0.5
|
||||
# Values to be used for image normalization
|
||||
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
|
||||
# Values to be used for image normalization
|
||||
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
|
||||
# Value of padding size
|
||||
_C.INPUT.PADDING = 10
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dataset
|
||||
# -----------------------------------------------------------------------------
|
||||
_C.DATASETS = CN()
|
||||
# Name of the dataset used for training; must be a key of the dataset factory in data/datasets/__init__.py
|
||||
_C.DATASETS.NAMES = ('market1501')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DataLoader
|
||||
# -----------------------------------------------------------------------------
|
||||
_C.DATALOADER = CN()
|
||||
# Number of data loading threads
|
||||
_C.DATALOADER.NUM_WORKERS = 8
|
||||
# Sampler for data loading
|
||||
_C.DATALOADER.SAMPLER = 'softmax'
|
||||
# Number of instance for one batch
|
||||
_C.DATALOADER.NUM_INSTANCE = 16
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# Solver
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.SOLVER = CN()
|
||||
_C.SOLVER.OPTIMIZER_NAME = "Adam"
|
||||
|
||||
_C.SOLVER.MAX_EPOCHS = 50
|
||||
|
||||
_C.SOLVER.BASE_LR = 3e-4
|
||||
_C.SOLVER.BIAS_LR_FACTOR = 2
|
||||
|
||||
_C.SOLVER.MOMENTUM = 0.9
|
||||
|
||||
_C.SOLVER.MARGIN = 0.3
|
||||
|
||||
_C.SOLVER.WEIGHT_DECAY = 0.0005
|
||||
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.
|
||||
|
||||
_C.SOLVER.GAMMA = 0.1
|
||||
_C.SOLVER.STEPS = (30, 55)
|
||||
|
||||
_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
|
||||
_C.SOLVER.WARMUP_ITERS = 500
|
||||
_C.SOLVER.WARMUP_METHOD = "linear"
|
||||
|
||||
_C.SOLVER.CHECKPOINT_PERIOD = 50
|
||||
_C.SOLVER.LOG_PERIOD = 100
|
||||
_C.SOLVER.EVAL_PERIOD = 50
|
||||
# Number of images per batch
|
||||
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
|
||||
# see 2 images per batch
|
||||
_C.SOLVER.IMS_PER_BATCH = 64
|
||||
|
||||
# Test-time options
|
||||
# TEST.IMS_PER_BATCH is the number of images per batch during testing
|
||||
_C.TEST = CN()
|
||||
_C.TEST.IMS_PER_BATCH = 128
|
||||
_C.TEST.WEIGHT = ""
|
||||
|
||||
# ---------------------------------------------------------------------------- #
|
||||
# Misc options
|
||||
# ---------------------------------------------------------------------------- #
|
||||
_C.OUTPUT_DIR = ""
|
@ -1,39 +0,0 @@
|
||||
# configuration for training market1501
|
||||
|
||||
dataset:
|
||||
name: market1501
|
||||
|
||||
aug:
|
||||
resize_size: [384, 128]
|
||||
random_mirror: True
|
||||
pad: 10
|
||||
random_crop: True
|
||||
random_erasing: True
|
||||
|
||||
train:
|
||||
optimizer: 'Adam'
|
||||
lr: 0.00035
|
||||
num_epochs: 80
|
||||
batch_size: 32
|
||||
sampler: 'softmax'
|
||||
wd: 0.0005
|
||||
step: [30, 55]
|
||||
factor: 0.1
|
||||
warmup_epoch: 5
|
||||
warmup_begin_lr: 0.0000035
|
||||
loss_fn: 'softmax'
|
||||
|
||||
test:
|
||||
batch_size: 128
|
||||
|
||||
network:
|
||||
name: 'Baseline'
|
||||
last_stride: 1
|
||||
gpus: '0'
|
||||
|
||||
misc:
|
||||
eval_step: 20
|
||||
save_step: 20
|
||||
log_interval: 100
|
||||
|
||||
|
@ -1,41 +0,0 @@
|
||||
# configuration for training market1501
|
||||
|
||||
dataset:
|
||||
name: market1501
|
||||
|
||||
aug:
|
||||
resize_size: [384, 128]
|
||||
random_mirror: True
|
||||
pad: 10
|
||||
random_crop: True
|
||||
random_erasing: True
|
||||
|
||||
train:
|
||||
optimizer: 'Adam'
|
||||
lr: 0.00035
|
||||
num_epochs: 400
|
||||
p_size: 16
|
||||
k_size: 4
|
||||
sampler: 'triplet'
|
||||
wd: 0.0005
|
||||
step: [80, 180, 300]
|
||||
factor: 0.1
|
||||
warmup_epoch: 20
|
||||
warmup_begin_lr: 0.0000035
|
||||
loss_fn: 'softmax_triplet'
|
||||
|
||||
|
||||
test:
|
||||
batch_size: 128
|
||||
|
||||
network:
|
||||
name: 'Baseline'
|
||||
last_stride: 1
|
||||
gpus: '1'
|
||||
|
||||
misc:
|
||||
eval_step: 50
|
||||
save_step: 50
|
||||
log_interval: 20
|
||||
|
||||
|
@ -1,40 +0,0 @@
|
||||
# configuration for training market1501
|
||||
|
||||
dataset:
|
||||
name: market1501
|
||||
|
||||
aug:
|
||||
resize_size: [384, 128]
|
||||
random_mirror: True
|
||||
pad: 10
|
||||
random_crop: True
|
||||
|
||||
train:
|
||||
optimizer: 'Adam'
|
||||
lr: 0.00035
|
||||
num_epochs: 400
|
||||
p_size: 32
|
||||
k_size: 4
|
||||
sampler: 'triplet'
|
||||
wd: 0.0005
|
||||
step: [80, 180, 300]
|
||||
factor: 0.1
|
||||
warmup_epoch: 20
|
||||
warmup_begin_lr: 0.0000035
|
||||
loss_fn: 'triplet'
|
||||
|
||||
|
||||
test:
|
||||
batch_size: 128
|
||||
|
||||
network:
|
||||
name: 'Baseline'
|
||||
last_stride: 1
|
||||
gpus: '1'
|
||||
|
||||
misc:
|
||||
eval_step: 50
|
||||
save_step: 50
|
||||
log_interval: 20
|
||||
|
||||
|
43
configs/softmax.yml
Normal file
43
configs/softmax.yml
Normal file
@ -0,0 +1,43 @@
|
||||
MODEL:
|
||||
PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
|
||||
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [384, 128]
|
||||
SIZE_TEST: [384, 128]
|
||||
PROB: 0.5 # random horizontal flip
|
||||
PADDING: 10
|
||||
|
||||
DATASETS:
|
||||
NAMES: ('market1501')
|
||||
|
||||
DATALOADER:
|
||||
SAMPLER: 'softmax'
|
||||
NUM_WORKERS: 8
|
||||
|
||||
SOLVER:
|
||||
OPTIMIZER_NAME: 'Adam'
|
||||
MAX_EPOCHS: 120
|
||||
BASE_LR: 0.00035
|
||||
BIAS_LR_FACTOR: 1
|
||||
WEIGHT_DECAY: 0.0005
|
||||
WEIGHT_DECAY_BIAS: 0.0005
|
||||
IMS_PER_BATCH: 64
|
||||
|
||||
STEPS: [30, 55]
|
||||
GAMMA: 0.1
|
||||
|
||||
WARMUP_FACTOR: 0.01
|
||||
WARMUP_ITERS: 5
|
||||
WARMUP_METHOD: 'linear'
|
||||
|
||||
CHECKPOINT_PERIOD: 20
|
||||
LOG_PERIOD: 100
|
||||
EVAL_PERIOD: 20
|
||||
|
||||
TEST:
|
||||
IMS_PER_BATCH: 256
|
||||
|
||||
OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_bs64_384x128"
|
||||
|
||||
|
45
configs/softmax_triplet.yml
Normal file
45
configs/softmax_triplet.yml
Normal file
@ -0,0 +1,45 @@
|
||||
MODEL:
|
||||
PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
|
||||
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [384, 128]
|
||||
SIZE_TEST: [384, 128]
|
||||
PROB: 0.5 # random horizontal flip
|
||||
PADDING: 10
|
||||
|
||||
DATASETS:
|
||||
NAMES: ('market1501')
|
||||
|
||||
DATALOADER:
|
||||
SAMPLER: 'softmax_triplet'
|
||||
NUM_INSTANCE: 4
|
||||
NUM_WORKERS: 8
|
||||
|
||||
SOLVER:
|
||||
OPTIMIZER_NAME: 'Adam'
|
||||
MAX_EPOCHS: 120
|
||||
BASE_LR: 0.00035
|
||||
BIAS_LR_FACTOR: 1
|
||||
WEIGHT_DECAY: 0.0005
|
||||
WEIGHT_DECAY_BIAS: 0.0005
|
||||
IMS_PER_BATCH: 64
|
||||
|
||||
STEPS: [40, 70]
|
||||
GAMMA: 0.1
|
||||
|
||||
WARMUP_FACTOR: 0.01
|
||||
WARMUP_ITERS: 10
|
||||
WARMUP_METHOD: 'linear'
|
||||
|
||||
CHECKPOINT_PERIOD: 40
|
||||
LOG_PERIOD: 100
|
||||
EVAL_PERIOD: 40
|
||||
|
||||
TEST:
|
||||
IMS_PER_BATCH: 256
|
||||
WEIGHT: "path"
|
||||
|
||||
OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_triplet_bs128_384x128"
|
||||
|
||||
|
@ -1,11 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
@ -1,79 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import yaml
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
__C = edict()
|
||||
opt = __C
|
||||
__C.seed = 0
|
||||
|
||||
__C.dataset = edict()
|
||||
__C.dataset.name = 'market1501'
|
||||
__C.dataset.num_classes = 751
|
||||
|
||||
__C.aug = edict()
|
||||
__C.aug.resize_size = [256, 128]
|
||||
__C.aug.color_jitter = False
|
||||
__C.aug.random_erasing = False
|
||||
__C.aug.random_mirror = True
|
||||
__C.aug.pad = 10
|
||||
__C.aug.random_crop = True
|
||||
|
||||
__C.train = edict()
|
||||
__C.train.optimizer = 'Adam'
|
||||
__C.train.lr = 3e-4
|
||||
__C.train.wd = 5e-4
|
||||
__C.train.momentum = 0.9
|
||||
__C.train.step = [80, 180, 300]
|
||||
__C.train.warmup_epoch = 20
|
||||
__C.train.warmup_begin_lr = 3e-6
|
||||
__C.train.factor = 0.1
|
||||
__C.train.margin = 0.3
|
||||
__C.train.num_epochs = 400
|
||||
__C.train.sampler = 'softmax'
|
||||
__C.train.p_size = 32 # number of person in a single gpu
|
||||
__C.train.k_size = 4 # number of images per person
|
||||
__C.train.batch_size = 128
|
||||
__C.train.loss_fn = 'softmax' # softmax, triplet, softmax_triplet
|
||||
__C.train.triplet_normalize = False
|
||||
|
||||
__C.test = edict()
|
||||
__C.test.batch_size = 128
|
||||
__C.test.load_path = '/mnt/truenas/scratch/xingyu.liao/DATA/mx-ckpt'
|
||||
|
||||
__C.network = edict()
|
||||
__C.network.depth = 50
|
||||
__C.network.name = 'Baseline'
|
||||
__C.network.last_stride = 1
|
||||
__C.network.gpus = "1"
|
||||
__C.network.workers = 8
|
||||
|
||||
__C.misc = edict()
|
||||
__C.misc.log_interval = 10
|
||||
__C.misc.eval_step = 50
|
||||
__C.misc.save_step = 50
|
||||
__C.misc.save_dir = ''
|
||||
|
||||
|
||||
def update_config(config_file):
    """Merge the experiment YAML file ``config_file`` into the global ``__C``.

    Every top-level key of the YAML file must already exist in ``__C``.
    Mapping values are merged entry by entry into the matching section;
    any other value replaces the existing one.

    Args:
        config_file: path of the YAML experiment configuration.

    Raises:
        ValueError: if the file contains a top-level key unknown to ``__C``.
    """
    with open(config_file) as f:
        # safe_load: a config file only needs plain scalars/lists/mappings,
        # and yaml.load() without an explicit Loader can construct arbitrary
        # Python objects (and is rejected by recent PyYAML releases).
        exp_config = edict(yaml.safe_load(f))

    # The file is fully parsed at this point, so it can be closed before merging.
    for k, v in exp_config.items():
        if k not in __C:
            raise ValueError("key must exist in configs.py")
        if isinstance(v, dict):
            for vk, vv in v.items():
                __C[k][vk] = vv
        else:
            __C[k] = v
|
126
core/loader.py
126
core/loader.py
@ -1,126 +0,0 @@
|
||||
from __future__ import print_function, absolute_import
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, Sampler, DataLoader
|
||||
|
||||
from utils import augmenter
|
||||
from .data_manager import init_dataset
|
||||
|
||||
|
||||
def read_image(img_path):
    """Read ``img_path`` as an RGB PIL image, retrying on transient IOError.

    Reading is retried until it succeeds so that IOErrors caused by heavy
    IO load do not abort the caller.  A path that does not exist can never
    succeed, so it is rejected immediately instead of retrying forever.

    Args:
        img_path: path of the image file.

    Returns:
        The image converted to "RGB" mode.

    Raises:
        IOError: if ``img_path`` does not exist.
    """
    # Local import keeps this block self-contained (the file-level imports
    # do not include os.path).
    import os.path as osp

    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))

    while True:
        try:
            return Image.open(img_path).convert("RGB")
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
|
||||
|
||||
|
||||
class ImageData(Dataset):
    """Dataset that loads images lazily from ``(path, pid, camid)`` records.

    ``dataset`` is an indexable collection of ``(img_path, pid, camid)``
    entries; ``transform`` is an optional callable applied to each loaded
    image (pass ``None`` to get the image as read).
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of records."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(image, pid, camid)`` for the record at index ``item``."""
        path, pid, camid = self.dataset[item]
        image = read_image(path)
        if self.transform is None:
            return image, pid, camid
        return self.transform(image), pid, camid
|
||||
|
||||
|
||||
class RandomIdentitySampler(Sampler):
    """Sample ``num_instances`` indices for every identity, in random order.

    ``data_source`` is iterated once at construction time; each element
    must unpack as ``(_, pid, _)``.  Indices are grouped by ``pid``.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        # pid -> list of positions in data_source carrying that pid
        self.index_dic = defaultdict(list)
        for position, sample in enumerate(data_source):
            _, pid, _ = sample
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        """Yield, for each identity in shuffled order, ``num_instances`` indices."""
        sampled = []
        for pid_pos in np.random.permutation(self.num_identities):
            candidates = self.index_dic[self.pids[pid_pos]]
            # Draw with replacement only when the identity has too few images.
            with_replacement = len(candidates) < self.num_instances
            picked = np.random.choice(candidates, size=self.num_instances, replace=with_replacement)
            sampled.extend(picked)
        return iter(sampled)

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return self.num_identities * self.num_instances
|
||||
|
||||
|
||||
def get_data_provider(opt):
    """Build the train and test data loaders described by ``opt``.

    Args:
        opt: configuration object; this function reads ``opt.network.gpus``
            (comma-separated GPU ids), ``opt.network.workers``, ``opt.aug``,
            ``opt.dataset.name``, ``opt.train`` and ``opt.test.batch_size``.

    Returns:
        ``(train_loader, test_loader, num_query)`` where the test loader
        iterates over the query images followed by the gallery images and
        ``num_query`` is the number of query images.

    Raises:
        ValueError: if ``opt.train.sampler`` is neither 'softmax' nor 'triplet'.
    """
    # Count the comma-separated ids instead of deriving the count from the
    # string length, which is wrong for multi-digit ids such as "10,11".
    num_gpus = len([gpu for gpu in opt.network.gpus.split(',') if gpu.strip()])
    test_batch_size = opt.test.batch_size * num_gpus
    num_workers = opt.network.workers

    # data augmenter
    random_mirror = opt.aug.get('random_mirror', False)
    pad = opt.aug.get('pad', False)
    random_crop = opt.aug.get('random_crop', False)
    random_erasing = opt.aug.get('random_erasing', False)

    h, w = opt.aug.resize_size
    # Same normalisation for training and testing.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_aug = [T.Resize((h, w))]
    if random_mirror:
        train_aug.append(T.RandomHorizontalFlip())
    if pad:
        train_aug.append(T.Pad(padding=pad))
    if random_crop:
        train_aug.append(T.RandomCrop((h, w)))
    train_aug.append(T.ToTensor())
    train_aug.append(normalize)
    if random_erasing:
        # Appended after ToTensor/Normalize, as in the original pipeline.
        train_aug.append(augmenter.RandomErasing())
    train_aug = T.Compose(train_aug)

    test_aug = T.Compose([T.Resize((h, w)), T.ToTensor(), normalize])

    dataset = init_dataset(opt.dataset.name)
    train_set = ImageData(dataset.train, train_aug)
    test_set = ImageData(dataset.query + dataset.gallery, test_aug)

    if opt.train.sampler == 'softmax':
        train_batch_size = opt.train.batch_size * num_gpus
        train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True,
                                  num_workers=num_workers, pin_memory=True, drop_last=True)
    elif opt.train.sampler == 'triplet':
        # p_size identities per GPU, k_size images per identity
        train_batch_size = opt.train.p_size * num_gpus * opt.train.k_size
        train_loader = DataLoader(train_set, batch_size=train_batch_size,
                                  sampler=RandomIdentitySampler(dataset.train, opt.train.k_size),
                                  num_workers=num_workers, pin_memory=True)
    else:
        raise ValueError('sampler must be softmax or triplet, but get {}'.format(opt.train.sampler))

    test_loader = DataLoader(test_set, batch_size=test_batch_size, num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader, len(dataset.query)  # return number of query
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from config import opt
|
||||
|
||||
train_loader, test_loader, num_query = get_data_provider(opt)
|
||||
from IPython import embed
|
||||
|
||||
embed()
|
187
core/solver.py
187
core/solver.py
@ -1,187 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from utils.meters import AverageMeter
|
||||
from utils.serialization import save_checkpoint
|
||||
|
||||
|
||||
class Solver(object):
    """Training loop and market1501-style evaluation for a re-id network.

    Args:
        opt: configuration object; ``fit`` reads ``opt.train.num_epochs``,
            ``opt.misc.log_interval``, ``opt.misc.eval_step``,
            ``opt.misc.save_step``, ``opt.misc.save_dir`` and
            ``opt.network.name``.
        net: the network.  In training mode it is called as
            ``score, feat = net(data)``; in eval mode its single output is
            used as the feature matrix.
    """

    def __init__(self, opt, net):
        self.opt = opt
        self.net = net
        self.loss = AverageMeter('loss')
        self.acc = AverageMeter('acc')

    def fit(self, train_data, test_data, num_query, optimizer, criterion, lr_scheduler):
        """Train for ``opt.train.num_epochs`` epochs, evaluating and saving periodically.

        Args:
            train_data: loader yielding ``(data, pids, _)`` batches; must
                expose ``batch_size`` and ``__len__`` (used for throughput).
            test_data: loader passed to ``test_func``; ``None`` disables evaluation.
            num_query: number of query samples at the head of ``test_data``.
            optimizer: optimizer whose param-group ``lr`` is set every epoch.
            criterion: called as ``criterion(score, feat, label)``.
            lr_scheduler: object whose ``update(epoch)`` returns the learning rate.
        """
        best_rank1 = -np.inf
        for epoch in range(self.opt.train.num_epochs):
            self.loss.reset()
            self.acc.reset()
            self.net.train()

            # update learning rate
            lr = lr_scheduler.update(epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            logging.info('Epoch [{}] learning rate update to {:.3e}'.format(epoch, lr))

            log_interval = self.opt.misc.log_interval
            tic = time.time()
            btic = time.time()
            for i, inputs in enumerate(train_data):
                data, pids, _ = inputs
                label = pids.cuda()
                score, feat = self.net(data)
                loss = criterion(score, feat, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.loss.update(loss.item())
                acc = (score.max(1)[1] == label.long()).float().mean().item()
                self.acc.update(acc)

                if log_interval and not (i + 1) % log_interval:
                    loss_name, loss_value = self.loss.get()
                    metric_name, metric_value = self.acc.get()
                    logging.info(
                        'Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\t'
                        '%s=%f',
                        epoch, i + 1, train_data.batch_size * log_interval / (time.time() - btic),
                        loss_name, loss_value,
                        metric_name, metric_value)
                    btic = time.time()

            loss_name, loss_value = self.loss.get()
            metric_name, metric_value = self.acc.get()
            throughput = int(train_data.batch_size * len(train_data) / (time.time() - tic))

            logging.info('[Epoch %d] training: %s=%f\t%s=%f',
                         epoch, loss_name, loss_value, metric_name, metric_value)
            logging.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f',
                         epoch, throughput, time.time() - tic)

            is_best = False
            if test_data is not None and self.opt.misc.eval_step and not (epoch + 1) % self.opt.misc.eval_step:
                rank1 = self.test_func(test_data, num_query)
                is_best = rank1 > best_rank1
                if is_best:
                    best_rank1 = rank1

            if not (epoch + 1) % self.opt.misc.save_step:
                # Only materialise the state dict when it is actually saved.
                # Unwrap a parallel wrapper exposing ``.module`` if there is
                # one, otherwise save the network itself.
                net = getattr(self.net, 'module', self.net)
                save_checkpoint({
                    'state_dict': net.state_dict(),
                    'epoch': epoch + 1,
                }, is_best=is_best, save_dir=self.opt.misc.save_dir,
                    filename=self.opt.network.name + '.pth.tar')

    def test_func(self, test_data, num_query):
        """Extract features, compute the query/gallery distances and report CMC/mAP.

        Args:
            test_data: loader yielding ``(data, pids, camids)`` batches, the
                first ``num_query`` samples being the query set.
            num_query: number of query samples.

        Returns:
            The rank-1 CMC score.
        """
        self.net.eval()
        feat, person, camera = [], [], []
        with torch.no_grad():
            for data, pids, camids in test_data:
                feat.append(self.net(data).cpu())
                person.extend(pids.numpy())
                camera.extend(camids.numpy())
        feat = torch.cat(feat, 0)

        qf = feat[:num_query]
        q_pids = np.asarray(person[:num_query])
        q_camids = np.asarray(camera[:num_query])
        gf = feat[num_query:]
        g_pids = np.asarray(person[num_query:])
        g_camids = np.asarray(camera[num_query:])

        logging.info("Extracted features for query set, obtained {}-by-{} matrix".format(
            qf.shape[0], qf.shape[1]))
        logging.info("Extracted features for gallery set, obtained {}-by-{} matrix".format(
            gf.shape[0], gf.shape[1]))

        logging.info("Computing distance matrix")

        # Squared euclidean distance: |q|^2 + |g|^2 - 2 * q.g
        # Written with torch.mm rather than the deprecated positional
        # addmm_(beta, alpha, mat1, mat2) overload.
        m, n = qf.shape[0], gf.shape[0]
        distmat = (torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n)
                   + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
                   - 2 * torch.mm(qf, gf.t()))
        distmat = distmat.numpy()

        logging.info("Computing CMC and mAP")
        cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        print("Results ----------")
        print("mAP: {:.1%}".format(mAP))
        print("CMC curve")
        for r in [1, 5, 10]:
            print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
        print("------------------")
        return cmc[0]

    @staticmethod
    def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
        """Evaluation with market1501 metric.

        Key: for each query identity, its gallery images from the same
        camera view are discarded.

        Args:
            distmat: ``(num_query, num_gallery)`` array of distances.
            q_pids, g_pids: identity arrays of the query / gallery samples.
            q_camids, g_camids: camera-id arrays of the query / gallery samples.
            max_rank: length of the CMC curve (clipped to the gallery size).

        Returns:
            ``(all_cmc, mAP)``: the CMC curve averaged over the valid
            queries and the mean average precision.

        Raises:
            AssertionError: if no query identity appears in the gallery.
        """
        num_q, num_g = distmat.shape
        if num_g < max_rank:
            max_rank = num_g
            print("Note: number of gallery samples is quite small, got {}".format(num_g))
        indices = np.argsort(distmat, axis=1)
        matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

        # compute cmc curve for each query
        all_cmc = []
        all_AP = []
        num_valid_q = 0.  # number of valid query
        for q_idx in range(num_q):
            # get query pid and camid
            q_pid = q_pids[q_idx]
            q_camid = q_camids[q_idx]

            # remove gallery samples that have the same pid and camid with query
            order = indices[q_idx]
            remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
            keep = np.invert(remove)

            # binary vector, positions with value 1 are correct matches
            orig_cmc = matches[q_idx][keep]
            if not np.any(orig_cmc):
                # this condition is true when query identity does not appear in gallery
                continue

            hits = orig_cmc.cumsum()
            cmc = hits.copy()
            cmc[cmc > 1] = 1

            all_cmc.append(cmc[:max_rank])
            num_valid_q += 1.

            # compute average precision
            # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
            num_rel = orig_cmc.sum()
            # precision at every rank, kept only at the ranks that are hits
            precision = hits / np.arange(1., hits.shape[0] + 1.)
            all_AP.append((precision * orig_cmc).sum() / num_rel)

        if not num_valid_q > 0:
            # Explicit raise so the check is not stripped under ``python -O``.
            raise AssertionError("Error: all query identities do not appear in gallery")

        all_cmc = np.asarray(all_cmc).astype(np.float32)
        all_cmc = all_cmc.sum(0) / num_valid_q
        mAP = np.mean(all_AP)

        return all_cmc, mAP
|
7
data/__init__.py
Normal file
7
data/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import make_data_loader
|
44
data/build.py
Normal file
44
data/build.py
Normal file
@ -0,0 +1,44 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .collate_batch import train_collate_fn, val_collate_fn
|
||||
from .datasets import init_dataset, ImageDataset
|
||||
from .samplers import RandomIdentitySampler
|
||||
from .transforms import build_transforms
|
||||
|
||||
|
||||
def make_data_loader(cfg):
    """Build the training and validation data loaders from ``cfg``.

    Args:
        cfg: configuration node; this function reads ``cfg.DATASETS.NAMES``,
            ``cfg.DATALOADER.NUM_WORKERS``, ``cfg.DATALOADER.SAMPLER``,
            ``cfg.DATALOADER.NUM_INSTANCE``, ``cfg.SOLVER.IMS_PER_BATCH`` and
            ``cfg.TEST.IMS_PER_BATCH`` (``cfg`` is also handed to
            ``build_transforms``).

    Returns:
        ``(train_loader, val_loader, num_query, num_classes)`` where the
        validation loader iterates over the query set followed by the
        gallery set, ``num_query`` is the size of the query set and
        ``num_classes`` is ``dataset.num_train_pids``.
    """
    train_transforms = build_transforms(cfg, is_train=True)
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS

    # NAMES is normally a plain string (e.g. 'market1501').  A one-element
    # list/tuple is unwrapped so it resolves to the same dataset; the two
    # original branches were identical and passed the value through as is.
    names = cfg.DATASETS.NAMES
    if isinstance(names, (list, tuple)) and len(names) == 1:
        names = names[0]
    # TODO: add multi dataset to train
    dataset = init_dataset(names)

    num_classes = dataset.num_train_pids
    train_set = ImageDataset(dataset.train, train_transforms)
    if cfg.DATALOADER.SAMPLER == 'softmax':
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        # Any other sampler name selects identity-based sampling.
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, val_loader, len(dataset.query), num_classes
|
18
data/collate_batch.py
Normal file
18
data/collate_batch.py
Normal file
@ -0,0 +1,18 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def train_collate_fn(batch):
    """Collate 4-field samples into ``(images, pids)`` for training.

    Each sample must unpack into four fields; the first is the image tensor
    and the second the identity label.  The last two fields are dropped.

    Returns:
        A tuple of the images stacked along a new leading dimension and an
        int64 tensor of the identity labels.
    """
    images, person_ids, _unused_a, _unused_b = zip(*batch)
    stacked = torch.stack(images, dim=0)
    labels = torch.tensor(person_ids, dtype=torch.int64)
    return stacked, labels
|
||||
|
||||
|
||||
def val_collate_fn(batch):
    """Collate 4-field samples into ``(images, pids, camids)`` for validation.

    Each sample must unpack into four fields; the fourth one is dropped.
    Identity and camera ids are returned as plain tuples, not tensors.
    """
    images, person_ids, camera_ids, _unused = zip(*batch)
    stacked = torch.stack(images, dim=0)
    return stacked, person_ids, camera_ids
|
25
data/datasets/__init__.py
Normal file
25
data/datasets/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from .cuhk03 import CUHK03
|
||||
from .dukemtmcreid import DukeMTMCreID
|
||||
from .market1501 import Market1501
|
||||
from .dataset_loader import ImageDataset
|
||||
|
||||
__factory = {
|
||||
'market1501': Market1501,
|
||||
'cuhk03': CUHK03,
|
||||
'dukemtmc': DukeMTMCreID
|
||||
}
|
||||
|
||||
|
||||
def get_names():
    """Return the keys view of the dataset factory (the registered dataset names)."""
    registered = __factory.keys()
    return registered
|
||||
|
||||
|
||||
def init_dataset(name, *args, **kwargs):
    """Instantiate the dataset registered under ``name``.

    Extra positional and keyword arguments are forwarded to the dataset
    class constructor.

    Raises:
        KeyError: if ``name`` is not a registered dataset name.
    """
    if name in __factory.keys():
        dataset_cls = __factory[name]
        return dataset_cls(*args, **kwargs)
    raise KeyError("Unknown datasets: {}".format(name))
|
95
data/datasets/bases.py
Normal file
95
data/datasets/bases.py
Normal file
@ -0,0 +1,95 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        """Summarise image records.

        ``data`` is a sized iterable of ``(_, pid, camid)`` records.
        Returns ``(num_pids, num_imgs, num_cams)``: the number of distinct
        pids, the number of records and the number of distinct camids.
        """
        unique_pids, unique_cams = set(), set()
        for _, pid, camid in data:
            unique_pids.add(pid)
            unique_cams.add(camid)
        return len(unique_pids), len(data), len(unique_cams)

    def get_videodata_info(self, data, return_tracklet_stats=False):
        """Summarise tracklet records.

        ``data`` is a sized iterable of ``(img_paths, pid, camid)`` records.
        Returns ``(num_pids, num_tracklets, num_cams)``; when
        ``return_tracklet_stats`` is true, the list of ``len(img_paths)``
        per record is appended as a fourth element.
        """
        unique_pids, unique_cams = set(), set()
        tracklet_lengths = []
        for img_paths, pid, camid in data:
            unique_pids.add(pid)
            unique_cams.add(camid)
            tracklet_lengths.append(len(img_paths))
        summary = (len(unique_pids), len(data), len(unique_cams))
        if not return_tracklet_stats:
            return summary
        return summary + (tracklet_lengths,)

    def print_dataset_statistics(self):
        """Print dataset statistics; subclasses must override this."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print an id / image / camera count table for the three subsets.

        Each argument is a list of ``(img_path, pid, camid)`` records and is
        summarised with ``get_imagedata_info``.
        """
        # Each summary is (num_pids, num_imgs, num_cams), which is also the
        # column order of the table rows below.
        train_summary = self.get_imagedata_info(train)
        query_summary = self.get_imagedata_info(query)
        gallery_summary = self.get_imagedata_info(gallery)

        rule = " ----------------------------------------"
        print("Dataset statistics:")
        print(rule)
        print(" subset | # ids | # images | # cameras")
        print(rule)
        print(" train | {:5d} | {:8d} | {:9d}".format(*train_summary))
        print(" query | {:5d} | {:8d} | {:9d}".format(*query_summary))
        print(" gallery | {:5d} | {:8d} | {:9d}".format(*gallery_summary))
        print(rule)
|
||||
|
||||
|
||||
class BaseVideoDataset(BaseDataset):
    """
    Base class of video reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print an id / tracklet / camera count table for the three subsets.

        Each argument is a list of ``(img_paths, pid, camid)`` records. The
        report also includes the min, max and mean number of images per
        tracklet over all three subsets combined.
        """
        train_info = self.get_videodata_info(train, return_tracklet_stats=True)
        query_info = self.get_videodata_info(query, return_tracklet_stats=True)
        gallery_info = self.get_videodata_info(gallery, return_tracklet_stats=True)

        # The last element of each summary is the list of tracklet lengths.
        tracklet_stats = train_info[3] + query_info[3] + gallery_info[3]
        shortest = np.min(tracklet_stats)
        longest = np.max(tracklet_stats)
        average = np.mean(tracklet_stats)

        rule = " -------------------------------------------"
        print("Dataset statistics:")
        print(rule)
        print(" subset | # ids | # tracklets | # cameras")
        print(rule)
        print(" train | {:5d} | {:11d} | {:9d}".format(train_info[0], train_info[1], train_info[2]))
        print(" query | {:5d} | {:11d} | {:9d}".format(query_info[0], query_info[1], query_info[2]))
        print(" gallery | {:5d} | {:11d} | {:9d}".format(gallery_info[0], gallery_info[1], gallery_info[2]))
        print(rule)
        print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(shortest, longest, average))
        print(rule)
|
259
data/datasets/cuhk03.py
Normal file
259
data/datasets/cuhk03.py
Normal file
@ -0,0 +1,259 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
import h5py
|
||||
import os.path as osp
|
||||
from scipy.io import loadmat
|
||||
from scipy.misc import imsave
|
||||
|
||||
from utils.iotools import mkdir_if_missing, write_json, read_json
|
||||
from .bases import BaseImageDataset
|
||||
|
||||
|
||||
class CUHK03(BaseImageDataset):
    """
    CUHK03

    Reference:
    Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.

    URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!

    Dataset statistics:
    # identities: 1360
    # images: 13164
    # cameras: 6
    # splits: 20 (classic)

    Args:
        root (str): directory that contains the ``cuhk03`` folder.
        split_id (int): split index (default: 0)
        cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
        cuhk03_classic_split (bool): if true, read the classic split json; otherwise the new-protocol
            split json (default: False)
        verbose (bool): print the loaded-dataset banner and statistics table (default: True)
    """
    dataset_dir = 'cuhk03'

    def __init__(self, root='/export/home/lxy/DATA/reid', split_id=0, cuhk03_labeled=False,
                 cuhk03_classic_split=False, verbose=True,
                 **kwargs):
        super(CUHK03, self).__init__()
        # All paths below are derived from <root>/cuhk03.
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
        self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')

        # Output folders for the png images extracted from the raw .mat file.
        self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
        self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')

        # Generated split files (written by _preprocess, read back below).
        self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
        self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')

        self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
        self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')

        # New-protocol configuration files; these must already exist on disk
        # (checked in _check_before_run).
        self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
        self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')

        self._check_before_run()
        self._preprocess()

        # Pick which generated split file to read.
        if cuhk03_labeled:
            image_type = 'labeled'
            split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
        else:
            image_type = 'detected'
            split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path

        splits = read_json(split_path)
        assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id,
                                                                                                    len(splits))
        split = splits[split_id]
        print("Split index = {}".format(split_id))

        train = split['train']
        query = split['query']
        gallery = split['gallery']

        if verbose:
            print("=> CUHK03 ({}) loaded".format(image_type))
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper

        Raises:
            RuntimeError: if the dataset folder, the release folder, the raw
                .mat file or either new-protocol .mat file is missing.
        """
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.data_dir):
            raise RuntimeError("'{}' is not available".format(self.data_dir))
        if not osp.exists(self.raw_mat_path):
            raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
        if not osp.exists(self.split_new_det_mat_path):
            raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
        if not osp.exists(self.split_new_lab_mat_path):
            raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))

    def _preprocess(self):
        """
        This function is a bit complex and ugly, what it does is
        1. Extract data from cuhk-03.mat and save as png images.
        2. Create 20 classic splits. (Li et al. CVPR'14)
        3. Create new split. (Zhong et al. CVPR'17)

        The work is skipped entirely when both image folders and all four
        split json files already exist.
        """
        print(
            "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)")
        # Nothing to do if every generated artefact is already present.
        if osp.exists(self.imgs_labeled_dir) and \
                osp.exists(self.imgs_detected_dir) and \
                osp.exists(self.split_classic_det_json_path) and \
                osp.exists(self.split_classic_lab_json_path) and \
                osp.exists(self.split_new_det_json_path) and \
                osp.exists(self.split_new_lab_json_path):
            return

        mkdir_if_missing(self.imgs_detected_dir)
        mkdir_if_missing(self.imgs_labeled_dir)

        print("Extract image data from {} and save as png".format(self.raw_mat_path))
        # NOTE(review): this h5py handle is never closed explicitly in this method.
        mat = h5py.File(self.raw_mat_path, 'r')

        def _deref(ref):
            # Resolve an HDF5 object reference to its data, transposed.
            return mat[ref][:].T

        def _process_images(img_refs, campid, pid, save_dir):
            # Save every non-empty image of one person and return the saved paths.
            img_paths = []  # Note: some persons only have images for one view
            for imgid, img_ref in enumerate(img_refs):
                img = _deref(img_ref)
                # skip empty cell
                if img.size == 0 or img.ndim < 3: continue
                # images are saved with the following format, index-1 (ensure uniqueness)
                # campid: index of camera pair (1-5)
                # pid: index of person in 'campid'-th camera pair
                # viewid: index of view, {1, 2}
                # imgid: index of image, (1-10)
                viewid = 1 if imgid < 5 else 2
                img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
                img_path = osp.join(save_dir, img_name)
                # Existing files are not overwritten.
                # NOTE(review): imsave is imported from scipy.misc at the top of the file; that
                # function is reportedly deprecated/removed in newer SciPy releases - confirm
                # the pinned SciPy version.
                if not osp.isfile(img_path):
                    imsave(img_path, img)
                img_paths.append(img_path)
            return img_paths

        def _extract_img(name):
            # name is 'detected' or 'labeled'; returns a list of
            # (campid, pid, img_paths) with 1-based campid and pid.
            print("Processing {} images (extract and save) ...".format(name))
            meta_data = []
            imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir
            for campid, camp_ref in enumerate(mat[name][0]):
                camp = _deref(camp_ref)
                num_pids = camp.shape[0]
                for pid in range(num_pids):
                    img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
                    assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid)
                    meta_data.append((campid + 1, pid + 1, img_paths))
                print("- done camera pair {} with {} identities".format(campid + 1, num_pids))
            return meta_data

        meta_detected = _extract_img('detected')
        meta_labeled = _extract_img('labeled')

        def _extract_classic_split(meta_data, test_split):
            # Partition meta_data into train/test by membership of [campid, pid]
            # in test_split; pids are relabelled consecutively from 0 within
            # each partition, in order of appearance.
            train, test = [], []
            num_train_pids, num_test_pids = 0, 0
            num_train_imgs, num_test_imgs = 0, 0
            for i, (campid, pid, img_paths) in enumerate(meta_data):

                if [campid, pid] in test_split:
                    for img_path in img_paths:
                        # The view id (third '_' field of the file name) is used as camera id.
                        camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                        test.append((img_path, num_test_pids, camid))
                    num_test_pids += 1
                    num_test_imgs += len(img_paths)
                else:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2]) - 1  # make it 0-based
                        train.append((img_path, num_train_pids, camid))
                    num_train_pids += 1
                    num_train_imgs += len(img_paths)
            return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs

        print("Creating classic splits (# = 20) ...")
        splits_classic_det, splits_classic_lab = [], []
        for split_ref in mat['testsets'][0]:
            test_split = _deref(split_ref).tolist()

            # create split for detected images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_detected, test_split)
            # In the classic protocol the same test list serves as both query and gallery.
            splits_classic_det.append({
                'train': train, 'query': test, 'gallery': test,
                'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
            })

            # create split for labeled images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_labeled, test_split)
            splits_classic_lab.append({
                'train': train, 'query': test, 'gallery': test,
                'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
            })

        write_json(splits_classic_det, self.split_classic_det_json_path)
        write_json(splits_classic_lab, self.split_classic_lab_json_path)

        def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
            # Build (img_path, pid, camid) records for the given indices and
            # return (records, number of distinct pids, number of indices).
            tmp_set = []
            unique_pids = set()
            for idx in idxs:
                img_name = filelist[idx][0]
                camid = int(img_name.split('_')[2]) - 1  # make it 0-based
                pid = pids[idx]
                if relabel: pid = pid2label[pid]
                img_path = osp.join(img_dir, img_name)
                tmp_set.append((img_path, int(pid), camid))
                unique_pids.add(pid)
            return tmp_set, len(unique_pids), len(idxs)

        def _extract_new_split(split_dict, img_dir):
            # split_dict is the loaded new-protocol .mat; its indices are
            # 1-based, hence the "- 1" below. Only training pids are relabelled.
            train_idxs = split_dict['train_idx'].flatten() - 1  # index-0
            pids = split_dict['labels'].flatten()
            train_pids = set(pids[train_idxs])
            pid2label = {pid: label for label, pid in enumerate(train_pids)}
            query_idxs = split_dict['query_idx'].flatten() - 1
            gallery_idxs = split_dict['gallery_idx'].flatten() - 1
            filelist = split_dict['filelist'].flatten()
            train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
            query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
            gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
            return train_info, query_info, gallery_info

        print("Creating new splits for detected images (767/700) ...")
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_det_mat_path),
            self.imgs_detected_dir,
        )
        # The new protocol has a single split, stored as a one-element list.
        splits = [{
            'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
            'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
        }]
        write_json(splits, self.split_new_det_json_path)

        print("Creating new splits for labeled images (767/700) ...")
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_lab_mat_path),
            self.imgs_labeled_dir,
        )
        splits = [{
            'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
            'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
        }]
        write_json(splits, self.split_new_lab_json_path)
|
45
data/datasets/dataset_loader.py
Normal file
45
data/datasets/dataset_loader.py
Normal file
@ -0,0 +1,45 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os.path as osp
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def read_image(img_path):
    """Load an image as RGB, retrying until the read succeeds.

    A missing file fails immediately; an IOError raised while opening or
    converting an existing file is reported and the read is attempted again,
    which rides out transient failures under heavy IO load.

    Raises:
        IOError: if ``img_path`` does not exist.
    """
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while True:
        try:
            return Image.open(img_path).convert('RGB')
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
    """Image Person ReID Dataset"""

    def __init__(self, dataset, transform=None):
        """Store the record list and the optional per-image transform.

        Args:
            dataset: indexable collection of ``(img_path, pid, camid)``.
            transform: callable applied to each loaded image, or None.
        """
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        """Number of records in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            tuple: (image, pid, camid, img_path); the image has been passed
            through ``transform`` when one was given.
        """
        img_path, pid, camid = self.dataset[index]
        image = read_image(img_path)
        transform = self.transform
        if transform is not None:
            image = transform(image)
        return image, pid, camid, img_path
|
106
data/datasets/dukemtmcreid.py
Normal file
106
data/datasets/dukemtmcreid.py
Normal file
@ -0,0 +1,106 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
import glob
|
||||
import re
|
||||
import urllib
|
||||
import zipfile
|
||||
|
||||
import os.path as osp
|
||||
|
||||
from utils.iotools import mkdir_if_missing
|
||||
from .bases import BaseImageDataset
|
||||
|
||||
|
||||
class DukeMTMCreID(BaseImageDataset):
    """
    DukeMTMC-reID

    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.

    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images:16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8

    Args:
        root (str): directory that contains (or will receive) the ``dukemtmc-reid`` folder.
        verbose (bool): print the loaded-dataset banner and statistics table (default: True)
    """
    dataset_dir = 'dukemtmc-reid'

    def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs):
        super(DukeMTMCreID, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')

        self._download_data()
        self._check_before_run()

        # Only the training identities are relabelled to consecutive ids.
        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        """Download and unzip the dataset unless its folder already exists."""
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        # Python 3 exposes urlretrieve under urllib.request; plain
        # ``urllib.urlretrieve`` only exists on Python 2.
        try:
            from urllib.request import urlretrieve
        except ImportError:
            from urllib import urlretrieve
        urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.dataset_dir)

    def _check_before_run(self):
        """Check if all files are available before going deeper

        Raises:
            RuntimeError: if the dataset folder or any of the train / query /
                gallery folders is missing.
        """
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Build the ``(img_path, pid, camid)`` list for one image folder.

        Args:
            dir_path: folder whose ``*.jpg`` files are listed.
            relabel: when true, map the pids found in this folder to
                consecutive labels starting at 0.

        Returns:
            list: one ``(img_path, pid, camid)`` tuple per image, with a
            0-based camid.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # Match against the file name only, so that digits in a parent
        # directory name can never be mistaken for the pid/camid fields.
        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(osp.basename(img_path)).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, pid, camid))

        return dataset
|
63
data/datasets/eval_reid.py
Normal file
63
data/datasets/eval_reid.py
Normal file
@ -0,0 +1,63 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric
    Key: for each query identity, its gallery images from the same camera view are discarded.

    Args:
        distmat: 2-D numpy array of shape (num_query, num_gallery); each row is
            sorted ascending to rank the gallery for that query.
        q_pids, g_pids: 1-D numpy arrays of query / gallery identity labels.
        q_camids, g_camids: 1-D numpy arrays of query / gallery camera ids.
        max_rank: number of CMC ranks to report; clamped to the gallery size.

    Returns:
        tuple: (all_cmc, mAP) where all_cmc is a float32 array of length
        ``max_rank`` averaged over the valid queries, and mAP is the mean
        average precision over the valid queries.

    Raises:
        AssertionError: if no query identity appears in the gallery.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0.  # number of valid query
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        # binary vector, positions with value 1 are correct matches
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1

        cmc = cmc[:max_rank]
        if cmc.size < max_rank:
            # Discarding same-camera gallery entries can leave fewer than
            # max_rank candidates. The curve has already reached 1 (there is at
            # least one match), so extend it with 1s to keep every row the same
            # length; otherwise the rows cannot be stacked and averaged.
            cmc = np.pad(cmc, (0, max_rank - cmc.size), mode='constant', constant_values=1)
        all_cmc.append(cmc)
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()
        # precision at each rank = (#matches so far) / rank, kept only at match positions
        ranks = np.arange(1, orig_cmc.size + 1, dtype=np.float64)
        precision_at_hits = (orig_cmc.cumsum() / ranks) * orig_cmc
        AP = precision_at_hits.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
|
70
core/data_manager.py → data/datasets/market1501.py
Executable file → Normal file
70
core/data_manager.py → data/datasets/market1501.py
Executable file → Normal file
@ -1,13 +1,18 @@
|
||||
from __future__ import print_function, absolute_import
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import glob
|
||||
import re
|
||||
from os import path as osp
|
||||
|
||||
"""Dataset classes"""
|
||||
import os.path as osp
|
||||
|
||||
from .bases import BaseImageDataset
|
||||
|
||||
|
||||
class Market1501(object):
|
||||
class Market1501(BaseImageDataset):
|
||||
"""
|
||||
Market1501
|
||||
Reference:
|
||||
@ -18,9 +23,10 @@ class Market1501(object):
|
||||
# identities: 1501 (+1 for background)
|
||||
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
|
||||
"""
|
||||
dataset_dir = 'Market-1501-v15.09.15'
|
||||
dataset_dir = 'market1501'
|
||||
|
||||
def __init__(self, root='/home/test2/DATA/market1501/raw/'):
|
||||
def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs):
|
||||
super(Market1501, self).__init__()
|
||||
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
|
||||
self.query_dir = osp.join(self.dataset_dir, 'query')
|
||||
@ -28,31 +34,21 @@ class Market1501(object):
|
||||
|
||||
self._check_before_run()
|
||||
|
||||
train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
|
||||
query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
|
||||
gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
|
||||
num_total_pids = num_train_pids + num_query_pids
|
||||
num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
|
||||
train = self._process_dir(self.train_dir, relabel=True)
|
||||
query = self._process_dir(self.query_dir, relabel=False)
|
||||
gallery = self._process_dir(self.gallery_dir, relabel=False)
|
||||
|
||||
print("=> Market1501 loaded")
|
||||
print("Dataset statistics:")
|
||||
print(" ------------------------------")
|
||||
print(" subset | # ids | # images")
|
||||
print(" ------------------------------")
|
||||
print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
|
||||
print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
|
||||
print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
|
||||
print(" ------------------------------")
|
||||
print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
|
||||
print(" ------------------------------")
|
||||
if verbose:
|
||||
print("=> Market1501 loaded")
|
||||
self.print_dataset_statistics(train, query, gallery)
|
||||
|
||||
self.train = train
|
||||
self.query = query
|
||||
self.gallery = gallery
|
||||
|
||||
self.num_train_pids = num_train_pids
|
||||
self.num_query_pids = num_query_pids
|
||||
self.num_gallery_pids = num_gallery_pids
|
||||
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
|
||||
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
|
||||
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
|
||||
|
||||
def _check_before_run(self):
|
||||
"""Check if all files are available before going deeper"""
|
||||
@ -79,31 +75,11 @@ class Market1501(object):
|
||||
dataset = []
|
||||
for img_path in img_paths:
|
||||
pid, camid = map(int, pattern.search(img_path).groups())
|
||||
if pid == -1:
|
||||
continue # junk images are just ignored
|
||||
if pid == -1: continue # junk images are just ignored
|
||||
assert 0 <= pid <= 1501 # pid == 0 means background
|
||||
assert 1 <= camid <= 6
|
||||
camid -= 1 # index starts from 0
|
||||
if relabel: pid = pid2label[pid]
|
||||
dataset.append((img_path, pid, camid))
|
||||
|
||||
num_pids = len(pid_container)
|
||||
num_imgs = len(dataset)
|
||||
return dataset, num_pids, num_imgs
|
||||
|
||||
|
||||
"""Create datasets"""
|
||||
|
||||
__factory = {
|
||||
'market1501': Market1501
|
||||
}
|
||||
|
||||
|
||||
def get_names():
|
||||
return __factory.keys()
|
||||
|
||||
|
||||
def init_dataset(name, *args, **kwargs):
|
||||
if name not in __factory.keys():
|
||||
raise KeyError("Unknown datasets: {}".format(name))
|
||||
return __factory[name](*args, **kwargs)
|
||||
return dataset
|
7
data/samplers/__init__.py
Normal file
7
data/samplers/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .triplet_sampler import RandomIdentitySampler
|
73
data/samplers/triplet_sampler.py
Normal file
73
data/samplers/triplet_sampler.py
Normal file
@ -0,0 +1,73 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
import copy
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.
    Args:
    - data_source (list): list of (img_path, pid, camid).
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): number of examples in a batch.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances

        # pid -> positions of that identity's records in data_source
        self.index_dic = defaultdict(list)
        for position, record in enumerate(self.data_source):
            _, pid, _ = record
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())

        # estimate number of examples in an epoch: every identity contributes
        # a whole number of K-sized groups, and at least one group
        estimate = 0
        for pid in self.pids:
            count = max(len(self.index_dic[pid]), self.num_instances)
            estimate += count - count % self.num_instances
        self.length = estimate

    def __iter__(self):
        group_size = self.num_instances
        groups_by_pid = defaultdict(list)

        for pid in self.pids:
            pool = copy.deepcopy(self.index_dic[pid])
            if len(pool) < group_size:
                # too few images for this identity: oversample with replacement
                pool = np.random.choice(pool, size=group_size, replace=True)
            random.shuffle(pool)
            # cut the shuffled pool into full groups; a trailing partial group is dropped
            usable = len(pool) - len(pool) % group_size
            for start in range(0, usable, group_size):
                groups_by_pid[pid].append(list(pool[start:start + group_size]))

        remaining_pids = copy.deepcopy(self.pids)
        ordered_idxs = []

        # keep drawing N identities while at least N still have groups left
        while len(remaining_pids) >= self.num_pids_per_batch:
            for pid in random.sample(remaining_pids, self.num_pids_per_batch):
                pending = groups_by_pid[pid]
                ordered_idxs.extend(pending.pop(0))
                if not pending:
                    remaining_pids.remove(pid)

        return iter(ordered_idxs)

    def __len__(self):
        return self.length
|
7
data/transforms/__init__.py
Normal file
7
data/transforms/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import build_transforms
|
31
data/transforms/build.py
Normal file
31
data/transforms/build.py
Normal file
@ -0,0 +1,31 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
import torchvision.transforms as T
|
||||
|
||||
from .transforms import RandomErasing
|
||||
|
||||
|
||||
def build_transforms(cfg, is_train=True):
    """Build the image preprocessing pipeline described by ``cfg.INPUT``.

    Args:
        cfg: config object; the ``INPUT`` fields PIXEL_MEAN, PIXEL_STD,
            SIZE_TRAIN, SIZE_TEST, PROB and PADDING are read.
        is_train: when true, build the augmenting training pipeline;
            otherwise the resize / to-tensor / normalize test pipeline.

    Returns:
        The composed transform.
    """
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    if not is_train:
        return T.Compose([
            T.Resize(cfg.INPUT.SIZE_TEST),
            T.ToTensor(),
            normalize_transform
        ])

    # Training: resize, random flip, pad then random-crop back to the training
    # size, convert and normalize, and finally random erasing on the tensor.
    stages = [
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        normalize_transform,
        RandomErasing(probability=cfg.INPUT.PROB, mean=cfg.INPUT.PIXEL_MEAN)
    ]
    return T.Compose(stages)
|
@ -1,57 +1,12 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
@contact: liaoxingyu2@jd.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Random2DTranslation(object):
    """
    With probability ``p``, first increase image size to (1 + 1/8), and then
    perform a random crop back to the target size. Otherwise the image is
    only resized to the target size.

    Args:
        height (int): target height.
        width (int): target width.
        p (float): probability of performing this transformation. Default: 0.5.
        interpolation: resampling filter passed to ``img.resize``.
    """

    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        # Skip the augmentation with probability (1 - p). The previous test
        # (`random.random() < self.p`) skipped it with probability p, which
        # inverted the documented meaning of `p`.
        if random.random() >= self.p:
            return img.resize((self.width, self.height), self.interpolation)

        # Enlarge by 1/8 in each dimension, then crop a target-sized window
        # at a random offset inside the enlarged image.
        new_width = int(round(self.width * 1.125))
        new_height = int(round(self.height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)

        x_maxrange = new_width - self.width
        y_maxrange = new_height - self.height
        x1 = int(round(random.uniform(0, x_maxrange)))
        y1 = int(round(random.uniform(0, y_maxrange)))

        croped_img = resized_img.crop(
            (x1, y1, x1 + self.width, y1 + self.height))
        return croped_img
|
||||
|
||||
|
||||
class RandomErasing(object):
|
||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||
@ -65,7 +20,7 @@ class RandomErasing(object):
|
||||
mean: Erasing value.
|
||||
"""
|
||||
|
||||
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
|
||||
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
|
||||
self.probability = probability
|
||||
self.mean = mean
|
||||
self.sl = sl
|
64
engine/inference.py
Normal file
64
engine/inference.py
Normal file
@ -0,0 +1,64 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from ignite.engine import Engine
|
||||
|
||||
from utils.reid_metric import R1_mAP
|
||||
|
||||
|
||||
def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            if device:
                # Follow the requested device instead of always using the
                # default CUDA device, so model and batch stay together.
                data = data.to(device)
            else:
                # Legacy behaviour when no device is given.
                data = data.cuda()
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
|
||||
|
||||
|
||||
def inference(
        cfg,
        model,
        val_loader,
        num_query
):
    """Evaluate ``model`` on ``val_loader`` and log mAP and CMC ranks 1/5/10.

    Args:
        cfg: config node; ``cfg.MODEL.DEVICE`` is forwarded to the evaluator.
        model: the network to evaluate.
        val_loader: iterable of evaluation batches handed to the engine.
        num_query: forwarded to ``R1_mAP``.
    """
    target_device = cfg.MODEL.DEVICE

    log = logging.getLogger("reid_baseline.inference")
    log.info("Start inferencing")

    reid_metrics = {'r1_mAP': R1_mAP(num_query)}
    evaluator = create_supervised_evaluator(model, metrics=reid_metrics, device=target_device)
    evaluator.run(val_loader)

    cmc, mAP = evaluator.state.metrics['r1_mAP']
    log.info('Validation Results')
    log.info("mAP: {:.1%}".format(mAP))
    # cmc is indexed from zero, so Rank-k lives at position k - 1.
    for rank in (1, 5, 10):
        log.info("CMC curve, Rank-{:<3}:{:.1%}".format(rank, cmc[rank - 1]))
|
150
engine/trainer.py
Normal file
150
engine/trainer.py
Normal file
@ -0,0 +1,150 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from ignite.engine import Engine, Events
|
||||
from ignite.handlers import ModelCheckpoint, Timer
|
||||
from ignite.metrics import RunningAverage
|
||||
|
||||
from utils.reid_metric import R1_mAP
|
||||
|
||||
|
||||
def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (callable): called as ``loss_fn(score, feat, target)`` and
            expected to return a scalar tensor
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function; each
        iteration outputs ``(loss, accuracy)`` as Python floats
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        if device:
            # Follow the requested device instead of always using the
            # default CUDA device, so model and batch stay together.
            img = img.to(device)
            target = target.to(device)
        else:
            # Legacy behaviour when no device is given.
            img = img.cuda()
            target = target.cuda()
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        loss.backward()
        optimizer.step()
        # compute acc: fraction of samples whose arg-max score equals the label
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
|
||||
|
||||
|
||||
def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            if device:
                # Follow the requested device instead of always using the
                # default CUDA device, so model and batch stay together.
                data = data.to(device)
            else:
                # Legacy behaviour when no device is given.
                data = data.cuda()
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
|
||||
|
||||
|
||||
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
        num_query
):
    """Build the ignite trainer/evaluator pair, attach handlers and run training.

    Args:
        cfg: config node; reads ``SOLVER.LOG_PERIOD``,
            ``SOLVER.CHECKPOINT_PERIOD``, ``SOLVER.EVAL_PERIOD``,
            ``SOLVER.MAX_EPOCHS``, ``OUTPUT_DIR``, ``MODEL.DEVICE`` and
            ``MODEL.NAME``.
        model: the network to train.
        train_loader: training batches; ``len()`` and ``batch_size`` are used
            for logging.
        val_loader: evaluation batches, run every ``EVAL_PERIOD`` epochs.
        optimizer: optimizer stepped by the trainer.
        scheduler: learning-rate scheduler, stepped at the start of each epoch.
        loss_fn: called as ``loss_fn(score, feat, target)``.
        num_query: forwarded to ``R1_mAP``.
    """
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # NOTE(review): both state_dict() calls are evaluated once, here, at
    # registration time. Confirm against the installed ignite version that the
    # saved checkpoints reflect the state at save time (optimizer state in
    # particular), or pass the objects themselves if the handler supports it.
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
                                                                     'optimizer': optimizer.state_dict()})
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # engine.state.iteration counts across epochs; fold it back into a
        # 1-based position inside the current epoch.
        iteration_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1

        if iteration_in_epoch % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, iteration_in_epoch, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        # The timer averages over steps, so value() is already the time per
        # batch; multiplying by step_count (as before) logged the total epoch
        # time under the "Time per batch" label.
        time_per_batch = timer.value()
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, time_per_batch,
                            train_loader.batch_size / time_per_batch))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
|
28
layers/__init__.py
Normal file
28
layers/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .triplet_loss import TripletLoss
|
||||
|
||||
|
||||
def make_loss(cfg):
    """Return the loss callable selected by ``cfg.DATALOADER.SAMPLER``.

    Args:
        cfg: config node; reads ``DATALOADER.SAMPLER`` and, for the
            triplet-based variants, ``SOLVER.MARGIN``.

    Returns:
        A function ``loss_func(score, feat, target)`` returning the loss:
        cross entropy on ``score`` for ``'softmax'``, the triplet loss on
        ``feat`` for ``'triplet'``, and their sum for ``'softmax_triplet'``.

    Raises:
        ValueError: if the sampler name is not one of the three above.
    """
    sampler = cfg.DATALOADER.SAMPLER

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)
        return loss_func

    if sampler not in ('triplet', 'softmax_triplet'):
        # Previously this case only printed a message and then failed with an
        # UnboundLocalError on `return loss_func`.
        raise ValueError('expected sampler should be softmax, triplet or softmax_triplet, '
                         'but got {}'.format(sampler))

    # Only the triplet-based variants need the triplet criterion.
    triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss

    if sampler == 'triplet':
        def loss_func(score, feat, target):
            return triplet(feat, target)[0]
    else:
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target) + triplet(feat, target)[0]
    return loss_func
|
@ -1,17 +1,10 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: xyliao1993@qq.com
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def normalize(x, axis=-1):
|
||||
@ -121,34 +114,3 @@ class TripletLoss(object):
|
||||
else:
|
||||
loss = self.ranking_loss(dist_an - dist_ap, y)
|
||||
return loss, dist_ap, dist_an
|
||||
|
||||
|
||||
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): weight.
        use_gpu (bool): kept for backward compatibility only; the smoothed
            targets are now always built on the device of ``inputs``.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth class indices with shape (batch_size)

        Returns:
            Scalar tensor: smoothed cross entropy averaged over the batch.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on the device/dtype of the
        # predictions. The former CPU round-trip followed by an unconditional
        # .cuda() failed for CPU inputs and for non-default GPUs.
        index = targets.unsqueeze(1).to(log_probs.device)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (- smoothed * log_probs).mean(0).sum()
        return loss
|
13
modeling/__init__.py
Normal file
13
modeling/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .baseline import Baseline
|
||||
|
||||
|
||||
def build_model(cfg, num_classes):
    """Instantiate the network named by ``cfg.MODEL.NAME``.

    Args:
        cfg: config node; reads ``MODEL.NAME``, ``MODEL.LAST_STRIDE`` and
            ``MODEL.PRETRAIN_PATH``.
        num_classes: forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``cfg.MODEL.NAME`` is not a supported model name
            (previously this surfaced as an UnboundLocalError).
    """
    if cfg.MODEL.NAME == 'resnet50':
        return Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH)
    raise ValueError("unsupported cfg.MODEL.NAME {!r}; expected 'resnet50'".format(cfg.MODEL.NAME))
|
6
modeling/backbones/__init__.py
Normal file
6
modeling/backbones/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
@ -1,17 +1,12 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: liaoxingyu@megvii.com
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import math
|
||||
|
||||
import torch as th
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
@ -98,7 +93,7 @@ class ResNet(nn.Module):
|
||||
return x
|
||||
|
||||
def load_param(self, model_path):
|
||||
param_dict = th.load(model_path)
|
||||
param_dict = torch.load(model_path)
|
||||
for i in param_dict:
|
||||
if 'fc' in i:
|
||||
continue
|
||||
@ -112,11 +107,3 @@ class ResNet(nn.Module):
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = ResNet(last_stride=2)
|
||||
import torch
|
||||
|
||||
x = net(torch.zeros(1, 3, 256, 128))
|
||||
print(x.shape)
|
@ -1,17 +1,12 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: xyliao1993@qq.com
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .resnet import ResNet
|
||||
from .backbones.resnet import ResNet
|
||||
|
||||
|
||||
def weights_init_kaiming(m):
|
||||
@ -40,11 +35,12 @@ def weights_init_classifier(m):
|
||||
class Baseline(nn.Module):
|
||||
in_planes = 2048
|
||||
|
||||
def __init__(self, num_classes=10, last_stride=1, model_path='/home/test2/.torch/models/resnet50-19c8e357.pth'):
|
||||
def __init__(self, num_classes, last_stride, model_path):
|
||||
super(Baseline, self).__init__()
|
||||
self.base = ResNet(last_stride)
|
||||
self.base.load_param(model_path)
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
# self.gap = nn.AdaptiveMaxPool2d(1)
|
||||
self.num_classes = num_classes
|
||||
|
||||
self.bottleneck = nn.BatchNorm1d(self.in_planes)
|
||||
@ -63,15 +59,3 @@ class Baseline(nn.Module):
|
||||
return cls_score, global_feat # global feature for triplet loss
|
||||
else:
|
||||
return feat
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# net = Baseline(751).cuda(1)
|
||||
import torch
|
||||
|
||||
net = ResNet(1).cuda(1)
|
||||
x = torch.ones(128, 3, 256, 128).cuda(1)
|
||||
y = net(x)
|
||||
from IPython import embed
|
||||
|
||||
embed()
|
@ -1,13 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: xyliao1993@qq.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
from .baseline import Baseline
|
@ -1,5 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
|
||||
python3 tools/test.py --config_file='configs/market_softmax_triplet.yml' \
|
||||
--load_model='/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax_triplet/350_Baseline350.pth.tar'
|
@ -1,8 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax/
|
||||
mkdir -p ${checkpoint_dir}
|
||||
|
||||
python3 tools/train.py --config_file='configs/market_softmax.yml' \
|
||||
--save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log
|
||||
|
@ -1,8 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax_triplet/
|
||||
mkdir -p ${checkpoint_dir}
|
||||
|
||||
python3 tools/train.py --config_file='configs/market_softmax_triplet.yml' \
|
||||
--save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log
|
||||
|
@ -1,8 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_triplet/
|
||||
mkdir -p ${checkpoint_dir}
|
||||
|
||||
python3 tools/train.py --config_file='configs/market_triplet.yml' \
|
||||
--save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log
|
||||
|
8
solver/__init__.py
Normal file
8
solver/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import make_optimizer
|
||||
from .lr_scheduler import WarmupMultiStepLR
|
25
solver/build.py
Normal file
25
solver/build.py
Normal file
@ -0,0 +1,25 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def make_optimizer(cfg, model):
    """Create the optimizer named by ``cfg.SOLVER.OPTIMIZER_NAME``.

    Every trainable parameter gets its own parameter group. Parameters whose
    name contains "bias" use ``BASE_LR * BIAS_LR_FACTOR`` and
    ``WEIGHT_DECAY_BIAS``; all others use ``BASE_LR`` and ``WEIGHT_DECAY``.
    ``MOMENTUM`` is passed only when the optimizer name is 'SGD'.

    Args:
        cfg: config node providing the ``SOLVER`` settings above.
        model: module whose ``named_parameters()`` are optimized.

    Returns:
        The constructed ``torch.optim`` optimizer.
    """
    param_groups = []
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            group_lr = cfg.SOLVER.BASE_LR
            group_decay = cfg.SOLVER.WEIGHT_DECAY
            if "bias" in param_name:
                group_lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
                group_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
            param_groups.append({"params": [param], "lr": group_lr, "weight_decay": group_decay})

    optimizer_cls = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        return optimizer_cls(param_groups, momentum=cfg.SOLVER.MOMENTUM)
    return optimizer_cls(param_groups)
|
56
solver/lr_scheduler.py
Normal file
56
solver/lr_scheduler.py
Normal file
@ -0,0 +1,56 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
from bisect import bisect_right
|
||||
import torch
|
||||
|
||||
|
||||
# FIXME ideally this would be achieved with a CombinedLRScheduler,
|
||||
# separating MultiStepLR with WarmupLR
|
||||
# but the current LRScheduler design doesn't allow it
|
||||
|
||||
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step learning-rate decay with an initial warmup phase.

    While ``last_epoch < warmup_iters`` the base learning rate is scaled by a
    warmup factor: the constant ``warmup_factor`` for ``"constant"`` warmup,
    or a linear ramp from ``warmup_factor`` up to 1 for ``"linear"`` warmup.
    Independently, the rate is multiplied by ``gamma`` once for every
    milestone that ``last_epoch`` has reached.

    Args:
        optimizer: wrapped optimizer.
        milestones: increasing step indices at which the rate decays.
        gamma (float): multiplicative decay applied per milestone.
        warmup_factor (float): scale applied at the start of warmup.
        warmup_iters (int): number of scheduler steps the warmup lasts.
        warmup_method (str): ``"constant"`` or ``"linear"``.
        last_epoch (int): forwarded to the base scheduler.

    Raises:
        ValueError: if ``milestones`` is not sorted in increasing order or
            ``warmup_method`` is not a supported name.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # The message used to be passed as ("... Got {}", milestones),
            # i.e. never formatted.
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        # The base constructor performs an initial step that calls get_lr(),
        # so every attribute above must be set before this call.
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate for each parameter group."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # float() keeps this a true division under Python 2 as well.
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # bisect_right counts the milestones already reached.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]
|
5
tests/__init__.py
Normal file
5
tests/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
26
tests/lr_scheduler_test.py
Normal file
26
tests/lr_scheduler_test.py
Normal file
@ -0,0 +1,26 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
sys.path.append('.')
|
||||
from solver.lr_scheduler import WarmupMultiStepLR
|
||||
from solver.build import make_optimizer
|
||||
from config import cfg
|
||||
|
||||
|
||||
class MyTestCase(unittest.TestCase):
    """Smoke test that prints the schedule produced by WarmupMultiStepLR."""

    def test_something(self):
        # No assertions are made: the test passes as long as stepping the
        # scheduler and optimizer for 50 epochs does not raise.
        net = nn.Linear(10, 10)
        optimizer = make_optimizer(cfg, net)
        lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10)
        for epoch in range(50):
            lr_scheduler.step()
            # three optimizer steps per epoch, printing the first group's rate
            for _ in range(3):
                print(epoch, lr_scheduler.get_lr()[0])
                optimizer.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -3,9 +3,3 @@
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
@ -4,64 +4,61 @@
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pprint import pprint
|
||||
from os import mkdir
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.backends import cudnn
|
||||
|
||||
import network
|
||||
from core.config import opt, update_config
|
||||
from core.loader import get_data_provider
|
||||
from core.solver import Solver
|
||||
|
||||
FORMAT = '[%(levelname)s]: %(message)s'
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=FORMAT,
|
||||
stream=sys.stdout
|
||||
)
|
||||
|
||||
|
||||
def test(args):
|
||||
logging.info('======= user config ======')
|
||||
logging.info(pprint(opt))
|
||||
logging.info(pprint(args))
|
||||
logging.info('======= end ======')
|
||||
|
||||
train_data, test_data, num_query = get_data_provider(opt)
|
||||
|
||||
net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride)
|
||||
net.load_state_dict(torch.load(args.load_model)['state_dict'])
|
||||
net = nn.DataParallel(net).cuda()
|
||||
|
||||
mod = Solver(opt, net)
|
||||
mod.test_func(test_data, num_query)
|
||||
sys.path.append('.')
|
||||
from config import cfg
|
||||
from data import make_data_loader
|
||||
from engine.inference import inference
|
||||
from modeling import build_model
|
||||
from utils.logger import setup_logger
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='reid model testing')
|
||||
parser.add_argument('--config_file', type=str, default=None,
|
||||
help='Optional config file for params')
|
||||
parser.add_argument('--load_model', type=str, required=True,
|
||||
help='load trained model for testing')
|
||||
parser = argparse.ArgumentParser(description="ReID Baseline Inference")
|
||||
parser.add_argument(
|
||||
"--config_file", default="", help="path to config file", type=str
|
||||
)
|
||||
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
|
||||
nargs=argparse.REMAINDER)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.config_file is not None:
|
||||
update_config(args.config_file)
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus
|
||||
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
|
||||
|
||||
if args.config_file != "":
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.freeze()
|
||||
|
||||
output_dir = cfg.OUTPUT_DIR
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
mkdir(output_dir)
|
||||
|
||||
logger = setup_logger("reid_baseline", output_dir, 0)
|
||||
logger.info("Using {} GPUS".format(num_gpus))
|
||||
logger.info(args)
|
||||
|
||||
if args.config_file != "":
|
||||
logger.info("Loaded configuration file {}".format(args.config_file))
|
||||
with open(args.config_file, 'r') as cf:
|
||||
config_str = "\n" + cf.read()
|
||||
logger.info(config_str)
|
||||
logger.info("Running with config:\n{}".format(cfg))
|
||||
|
||||
cudnn.benchmark = True
|
||||
test(args)
|
||||
|
||||
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
|
||||
model = build_model(cfg, num_classes)
|
||||
model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
|
||||
|
||||
inference(cfg, model, val_loader, num_query)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
128
tools/train.py
128
tools/train.py
@ -4,95 +4,83 @@
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pprint import pprint
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.backends import cudnn
|
||||
|
||||
import network
|
||||
from core.config import opt, update_config
|
||||
from core.loader import get_data_provider
|
||||
from core.solver import Solver
|
||||
from utils.loss import TripletLoss
|
||||
from utils.lr_scheduler import LRScheduler
|
||||
sys.path.append('.')
|
||||
from config import cfg
|
||||
from data import make_data_loader
|
||||
from engine.trainer import do_train
|
||||
from modeling import build_model
|
||||
from layers import make_loss
|
||||
from solver import make_optimizer, WarmupMultiStepLR
|
||||
|
||||
FORMAT = '[%(levelname)s]: %(message)s'
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=FORMAT,
|
||||
stream=sys.stdout
|
||||
)
|
||||
from utils.logger import setup_logger
|
||||
|
||||
|
||||
def train(args):
|
||||
logging.info('======= user config ======')
|
||||
logging.info(pprint(opt))
|
||||
logging.info(pprint(args))
|
||||
logging.info('======= end ======')
|
||||
def train(cfg):
|
||||
# prepare dataset
|
||||
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
|
||||
# prepare model
|
||||
model = build_model(cfg, num_classes)
|
||||
|
||||
train_data, test_data, num_query = get_data_provider(opt)
|
||||
optimizer = make_optimizer(cfg, model)
|
||||
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
||||
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
||||
|
||||
net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride)
|
||||
loss_func = make_loss(cfg)
|
||||
|
||||
optimizer = getattr(torch.optim, opt.train.optimizer)(net.parameters(), lr=opt.train.lr, weight_decay=opt.train.wd)
|
||||
ce_loss = nn.CrossEntropyLoss()
|
||||
triplet_loss = TripletLoss(margin=opt.train.margin)
|
||||
arguments = {}
|
||||
|
||||
def ce_loss_func(scores, feat, labels):
|
||||
ce = ce_loss(scores, labels)
|
||||
return ce
|
||||
|
||||
def tri_loss_func(scores, feat, labels):
|
||||
tri = triplet_loss(feat, labels)[0]
|
||||
return tri
|
||||
|
||||
def ce_tri_loss_func(scores, feat, labels):
|
||||
ce = ce_loss(scores, labels)
|
||||
triplet = triplet_loss(feat, labels)[0]
|
||||
return ce + triplet
|
||||
|
||||
if opt.train.loss_fn == 'softmax':
|
||||
loss_fn = ce_loss_func
|
||||
elif opt.train.loss_fn == 'triplet':
|
||||
loss_fn = tri_loss_func
|
||||
elif opt.train.loss_fn == 'softmax_triplet':
|
||||
loss_fn = ce_tri_loss_func
|
||||
else:
|
||||
raise ValueError('Unknown loss func {}'.format(opt.train.loss_fn))
|
||||
|
||||
lr_scheduler = LRScheduler(base_lr=opt.train.lr, step=opt.train.step,
|
||||
factor=opt.train.factor, warmup_epoch=opt.train.warmup_epoch,
|
||||
warmup_begin_lr=opt.train.warmup_begin_lr)
|
||||
net = nn.DataParallel(net).cuda()
|
||||
mod = Solver(opt, net)
|
||||
mod.fit(train_data=train_data, test_data=test_data, num_query=num_query, optimizer=optimizer,
|
||||
criterion=loss_fn, lr_scheduler=lr_scheduler)
|
||||
do_train(
|
||||
cfg,
|
||||
model,
|
||||
train_loader,
|
||||
val_loader,
|
||||
optimizer,
|
||||
scheduler,
|
||||
loss_func,
|
||||
num_query
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='reid model training')
|
||||
parser.add_argument('--config_file', type=str, default=None, required=True,
|
||||
help='Optional config file for params')
|
||||
parser.add_argument('--save_dir', type=str, default=None, required=True,
|
||||
help='model save checkpoint directory')
|
||||
parser = argparse.ArgumentParser(description="ReID Baseline Training")
|
||||
parser.add_argument(
|
||||
"--config_file", default="", help="path to config file", type=str
|
||||
)
|
||||
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
|
||||
nargs=argparse.REMAINDER)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.config_file is not None:
|
||||
update_config(args.config_file)
|
||||
opt.misc.save_dir = args.save_dir
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus
|
||||
|
||||
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
|
||||
|
||||
if args.config_file != "":
|
||||
cfg.merge_from_file(args.config_file)
|
||||
cfg.merge_from_list(args.opts)
|
||||
cfg.freeze()
|
||||
|
||||
output_dir = cfg.OUTPUT_DIR
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
logger = setup_logger("reid_baseline", output_dir, 0)
|
||||
logger.info("Using {} GPUS".format(num_gpus))
|
||||
logger.info(args)
|
||||
|
||||
if args.config_file != "":
|
||||
logger.info("Loaded configuration file {}".format(args.config_file))
|
||||
with open(args.config_file, 'r') as cf:
|
||||
config_str = "\n" + cf.read()
|
||||
logger.info(config_str)
|
||||
logger.info("Running with config:\n{}".format(cfg))
|
||||
|
||||
cudnn.benchmark = True
|
||||
train(args)
|
||||
train(cfg)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -1,11 +1,6 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: xyliao1993@qq.com
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
39
utils/iotools.py
Normal file
39
utils/iotools.py
Normal file
@ -0,0 +1,39 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
|
||||
import os.path as osp
|
||||
|
||||
|
||||
def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) unless it exists.

    An empty path is treated as "nothing to create". This is what
    ``os.path.dirname`` returns for a bare file name, and ``os.makedirs('')``
    would otherwise raise.

    Args:
        directory (str): path of the directory to ensure.

    Raises:
        OSError: if creation fails for any reason other than the directory
            already existing.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as e:
        # Something else may have created it between the check and the call;
        # only that case is tolerated.
        if e.errno != errno.EEXIST:
            raise
|
||||
|
||||
|
||||
def check_isfile(path):
    """Return whether ``path`` is an existing regular file.

    Prints a warning to stdout when it is not.
    """
    if osp.isfile(path):
        return True
    print("=> Warning: no file found at '{}' (ignored)".format(path))
    return False
|
||||
|
||||
|
||||
def read_json(fpath):
    """Parse the JSON file at ``fpath`` and return the decoded object."""
    with open(fpath, 'r') as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def write_json(obj, fpath):
    """Serialize ``obj`` as indented JSON and write it to ``fpath``.

    The parent directory of ``fpath`` is created first when it is missing.

    Args:
        obj: JSON-serializable object.
        fpath: Destination file path; an existing file is overwritten.
    """
    parent = osp.dirname(fpath)
    # A bare file name has an empty dirname; there is no directory to
    # create in that case, and asking for one would fail.
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))
|
30
utils/logger.py
Normal file
30
utils/logger.py
Normal file
@ -0,0 +1,30 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def setup_logger(name, save_dir, distributed_rank):
    """Return the logger called ``name``, configured for console/file output.

    The logger level is set to DEBUG. For rank 0, a stdout handler is
    attached and, when ``save_dir`` is non-empty, a file handler writing to
    ``<save_dir>/log.txt`` (opened in ``'w'`` mode). Calling this function
    again with the same arguments does not attach the handlers a second
    time, so messages are not duplicated and the log file is not reopened.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        save_dir: Directory for ``log.txt``; falsy disables file logging.
        distributed_rank: Process rank; ranks > 0 get no handlers.

    Returns:
        logging.Logger: the configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    # FileHandler subclasses StreamHandler, so it has to be excluded when
    # looking for an already-attached console handler.
    has_console = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if save_dir:
        log_path = os.path.join(save_dir, "log.txt")
        # FileHandler stores the absolute path in ``baseFilename``.
        abs_log_path = os.path.abspath(log_path)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == abs_log_path
            for h in logger.handlers
        )
        if not has_file:
            fh = logging.FileHandler(log_path, mode='w')
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger
|
@ -1,65 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
class LRScheduler(object):
    """Step-decay learning rate schedule with an optional warmup phase.

    ``update(num_epoch)`` returns the learning rate for the given epoch and
    also stores it in ``self.learning_rate``.

    Parameters
    ----------
    base_lr : float, optional
        Learning rate reached at the end of warmup and used as the base of
        the step decay.
    step : sequence of int, optional
        Epoch boundaries; each boundary that is ``<= num_epoch`` multiplies
        the learning rate by ``factor`` once.
    factor : float, optional
        Multiplicative decay applied per passed boundary.
    warmup_epoch : int
        Number of initial epochs handled by the warmup rule.
    warmup_begin_lr : float
        Learning rate at the start of warmup; must not exceed ``base_lr``.
    warmup_mode : string
        'linear' interpolates from ``warmup_begin_lr`` to ``base_lr`` in
        equal increments; 'constant' stays at ``warmup_begin_lr``.
    """

    def __init__(self, base_lr=0.01, step=(30, 60), factor=0.1,
                 warmup_epoch=0, warmup_begin_lr=0, warmup_mode='linear'):
        # Validate arguments first, in the same order as the checks were
        # originally performed.
        assert isinstance(warmup_epoch, int)
        if warmup_begin_lr > base_lr:
            raise ValueError("Base lr has to be higher than warmup_begin_lr")
        if warmup_epoch < 0:
            raise ValueError("Warmup steps has to be positive or 0")
        if warmup_mode not in ('linear', 'constant'):
            raise ValueError("Supports only linear and constant modes of warmup")

        self.base_lr = base_lr
        self.learning_rate = base_lr
        self.step = step
        self.factor = factor
        self.warmup_epoch = warmup_epoch
        self.warmup_final_lr = base_lr
        self.warmup_begin_lr = warmup_begin_lr
        self.warmup_mode = warmup_mode

    def update(self, num_epoch):
        """Compute, store and return the learning rate for ``num_epoch``."""
        if num_epoch < self.warmup_epoch:
            # Still warming up.
            if self.warmup_mode == 'linear':
                lr_span = self.warmup_final_lr - self.warmup_begin_lr
                self.learning_rate = self.warmup_begin_lr + lr_span * \
                    num_epoch / self.warmup_epoch
            elif self.warmup_mode == 'constant':
                self.learning_rate = self.warmup_begin_lr
            return self.learning_rate

        # Past warmup: decay once for every boundary already reached.
        passed = sum(1 for boundary in self.step if boundary <= num_epoch)
        self.learning_rate = self.base_lr * pow(self.factor, passed)
        return self.learning_rate
|
||||
|
@ -1,54 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: xyliao1993@qq.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AverageMeter(object):
    """Named running statistics (mean and sample std) of a value stream.

    Attributes are maintained incrementally by ``update``:
    ``val`` is the last value, ``sum`` the running sum, ``var`` the running
    sum of squares, ``n`` the accumulated count, and ``mean`` / ``std`` the
    derived statistics (NaN until the first update).
    """

    def __init__(self, name):
        self.name = name
        # All accumulators share their initial state with reset().
        self.reset()

    def update(self, value, n=1):
        """Add ``value`` to the statistics and advance the count by ``n``.

        NOTE(review): ``value`` is added to the running sum as-is (it is not
        multiplied by ``n``), so callers passing ``n > 1`` presumably supply
        an already-summed value -- confirm against call sites.
        """
        self.val = value
        self.sum += value
        self.var += value * value
        self.n += n

        if self.n == 0:
            self.mean, self.std = np.nan, np.nan
        elif self.n == 1:
            # A single sample has no defined spread.
            self.mean, self.std = self.sum, np.inf
        else:
            self.mean = self.sum / self.n
            variance = (self.var - self.n * self.mean * self.mean) / (self.n - 1.0)
            # Rounding error can push the variance of (near-)constant data
            # slightly below zero, which would make math.sqrt raise.
            self.std = math.sqrt(max(variance, 0.0))

    def value(self):
        """Return ``(mean, std)``."""
        return self.mean, self.std

    def get(self):
        """Return ``(name, mean)``."""
        return self.name, self.mean

    def reset(self):
        """Clear all accumulated statistics."""
        self.n = 0
        self.sum = 0.0
        self.var = 0.0
        self.val = 0.0
        self.mean = np.nan
        self.std = np.nan
|
48
utils/reid_metric.py
Normal file
48
utils/reid_metric.py
Normal file
@ -0,0 +1,48 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from ignite.metrics import Metric
|
||||
|
||||
from data.datasets.eval_reid import eval_func
|
||||
|
||||
|
||||
class R1_mAP(Metric):
    """Ignite metric that accumulates features and evaluates CMC / mAP.

    Features, person ids and camera ids are collected batch by batch in
    ``update``. In ``compute`` the first ``num_query`` accumulated entries
    are treated as the query set and the remainder as the gallery; their
    pairwise squared Euclidean distance matrix is handed to ``eval_func``.

    Args:
        num_query: Number of leading samples that form the query set.
        max_rank: Stored on the instance.
            NOTE(review): not forwarded to ``eval_func`` in this class --
            confirm whether that is intended.
    """

    def __init__(self, num_query, max_rank=50):
        super(R1_mAP, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank

    def reset(self):
        """Drop everything accumulated so far."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Accumulate one batch given as ``(feat, pid, camid)``."""
        feat, pid, camid = output
        self.feats.append(feat)
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):
        """Return ``(cmc, mAP)`` as produced by ``eval_func``."""
        feats = torch.cat(self.feats, dim=0)
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])
        m, n = qf.shape[0], gf.shape[0]
        # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 * q . g
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # Keyword beta/alpha: the positional (beta, alpha, mat1, mat2)
        # overload of addmm_ is deprecated.
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.cpu().numpy()
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        return cmc, mAP
|
@ -1,35 +0,0 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: liaoxingyu
|
||||
@contact: xyliao1993@qq.com
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import errno
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import os.path as osp
|
||||
import torch
|
||||
|
||||
|
||||
def mkdir_if_missing(dir_path):
    """Create ``dir_path`` recursively; an already-existing path is not an error.

    Raises:
        OSError: For any failure other than ``errno.EEXIST``.
    """
    try:
        os.makedirs(dir_path)
    except OSError as err:
        already_present = err.errno == errno.EEXIST
        if not already_present:
            raise
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    """Save ``state`` to ``<save_dir>/<epoch>_<filename>`` via ``torch.save``.

    ``save_dir`` is created when missing. When ``is_best`` is truthy the
    saved file is additionally copied to ``<save_dir>/model_best.pth.tar``.

    Args:
        state: Object to serialize; ``state['epoch']`` prefixes the file name.
        is_best: Whether to also store a ``model_best.pth.tar`` copy.
        save_dir: Output directory.
        filename: Base name of the checkpoint file.
    """
    checkpoint_name = str(state['epoch']) + '_' + filename
    checkpoint_path = osp.join(save_dir, checkpoint_name)
    mkdir_if_missing(save_dir)
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copy(checkpoint_path, osp.join(save_dir, 'model_best.pth.tar'))
|
Loading…
x
Reference in New Issue
Block a user