commit 5365619606 ("add config")
parent 3457f7890c

README.rst | 137 lines changed
@@ -28,41 +28,28 @@ Model zoo: https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.
 Installation
 ---------------
 
-We recommend using `conda <https://www.anaconda.com/distribution/>`_ to manage the packages.
+Make sure your `conda <https://www.anaconda.com/distribution/>`_ is installed.
 
-1. Clone ``deep-person-reid`` to your preferred directory.
-
-.. code-block:: bash
-
-    $ git clone https://github.com/KaiyangZhou/deep-person-reid.git
-
-2. Create a conda environment (the default name is ``torchreid``).
-
-.. code-block:: bash
-
-    $ cd deep-person-reid/
-    $ conda env create -f environment.yml
-    $ conda activate torchreid
-
-Do check whether ``which python`` and ``which pip`` point to the right path.
-
-3. Install tensorboard.
-
-.. code-block:: bash
-
-    $ pip install tb-nightly
-
-4. Install PyTorch and torchvision (select the proper cuda version to suit your machine).
-
-.. code-block:: bash
-
-    $ conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
-
-5. Install ``torchreid``.
-
-.. code-block:: bash
-
-    $ python setup.py develop
+.. code-block:: bash
+
+    # cd to your preferred directory and clone this repo
+    git clone https://github.com/KaiyangZhou/deep-person-reid.git
+
+    # create environment
+    cd deep-person-reid/
+    conda create --name torchreid python=3.7
+    conda activate torchreid
+
+    # install dependencies
+    # make sure `which python` and `which pip` point to the correct path
+    pip install -r requirements.txt
+
+    # install torch and torchvision (select the proper cuda version to suit your machine)
+    conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
+
+    # install torchreid (don't need to re-build it if you modify the source code)
+    python setup.py develop
 
 
 Get started: 30 seconds to Torchreid
@@ -80,9 +67,11 @@ Get started: 30 seconds to Torchreid
     datamanager = torchreid.data.ImageDataManager(
         root='reid-data',
         sources='market1501',
+        targets='market1501',
         height=256,
         width=128,
-        batch_size=32,
+        batch_size_train=32,
+        batch_size_test=100,
         transforms=['random_flip', 'random_crop']
     )
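For orientation, the hunk above only touches the data-manager call; the new arguments feed the rest of the README's "30 seconds" pipeline roughly as follows. This is a sketch using API names that appear elsewhere in this commit, not a quote of the unchanged README text:

.. code-block:: python

    import torchreid

    # Data manager with the renamed batch-size arguments from this hunk.
    datamanager = torchreid.data.ImageDataManager(
        root='reid-data',
        sources='market1501',
        targets='market1501',
        height=256,
        width=128,
        batch_size_train=32,
        batch_size_test=100,
        transforms=['random_flip', 'random_crop']
    )
    # Model, optimizer and engine as used by scripts/main.py in this commit.
    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='softmax',
        pretrained=True
    )
    optimizer = torchreid.optim.build_optimizer(model, optim='adam', lr=0.0003)
    engine = torchreid.engine.ImageSoftmaxEngine(datamanager, model, optimizer)
    engine.run(max_epoch=60, save_dir='log/resnet50', print_freq=10)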
@@ -138,46 +127,72 @@ Get started: 30 seconds to Torchreid
 
 A unified interface
 -----------------------
-In "deep-person-reid/scripts/", we provide a unified interface to train and test a model.
+In "deep-person-reid/scripts/", we provide a unified interface to train and test a model. See "scripts/main.py" and "scripts/default_config.py" for more details. "configs/" contains some predefined configs which you can use as a starting point.
 
-For instance, to train an image reid model on Market1501 using softmax, you can do
+Below we provide examples to train and test `OSNet <https://arxiv.org/abs/1905.00953>`_. Assume :code:`PATH_TO_DATA` is the directory containing reid datasets.
+
+Conventional setting
+^^^^^^^^^^^^^^^^^^^^^
+
+To train OSNet on Market1501, do
 
 .. code-block:: bash
 
+    # suppose you are in deep-person-reid/
     python scripts/main.py \
-    --root PATH_TO_DATA \
-    --app image \
-    --loss softmax \
-    --label-smooth \
-    -s market1501 \
-    -a resnet50 \
-    --optim adam \
-    --lr 0.0003 \
-    --max-epoch 60 \
-    --stepsize 20 40 \
-    --batch-size 32 \
-    --transforms random_flip random_crop \
-    --save-dir log/resnet50-market1501-softmax \
+    --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml \
+    --transforms random_flip random_erase \
+    --root $PATH_TO_DATA \
    --gpu-devices 0
 
-To evaluate a trained model, do
+The config file sets Market1501 as the default dataset. If you want to use DukeMTMC-reID, do
 
 .. code-block:: bash
 
     python scripts/main.py \
-    --root PATH_TO_DATA \
-    --app image \
-    --loss softmax \
-    -s market1501 \
-    -a resnet50 \
-    --batch-size 32 \
-    --evaluate \
-    --load-weights log/resnet50-market1501-softmax/model.pth.tar-60 \
-    --save-dir log/eval-resnet50 \
+    --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml \
+    -s dukemtmcreid \
+    -t dukemtmcreid \
+    --transforms random_flip random_erase \
+    --root $PATH_TO_DATA \
     --gpu-devices 0
 
-Please refer to ``default_parser.py`` and ``main.py`` for more details.
+The code will automatically (download and) load the ImageNet pretrained weights. After the training is done, the model will be saved as "log/osnet_x1_0_market1501_softmax_cosinelr/model.pth.tar-250".
+
+Evaluation is automatically performed at the end of training.
+
+To run the test again using the trained model, do
+
+.. code-block:: bash
+
+    python scripts/main.py \
+    --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad_cosine.yaml \
+    --root $PATH_TO_DATA \
+    --gpu-devices 0 \
+    model.load_weights log/osnet_x1_0_market1501_softmax_cosinelr/model.pth.tar-250 \
+    test.evaluate True
+
+
+Cross-domain setting
+^^^^^^^^^^^^^^^^^^^^^
+
+Suppose you want to train OSNet on DukeMTMC-reID and test its performance on Market1501; you can do
+
+.. code-block:: bash
+
+    python scripts/main.py \
+    --config-file configs/im_osnet_x1_0_softmax_256x128_amsgrad.yaml \
+    -s dukemtmcreid \
+    -t market1501 \
+    --transforms random_flip color_jitter \
+    --root $PATH_TO_DATA \
+    --gpu-devices 0
+
+Here we only test the cross-domain performance. If you also want to test the same-domain performance, set :code:`-t dukemtmcreid market1501`, which evaluates the model on the two datasets separately.
+
+Different from the same-domain setting, here we replace :code:`random_erase` with :code:`color_jitter`. This can improve the generalization performance on the unseen target dataset.
+
+Pretrained models are available in the `Model Zoo <https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html>`_.
 
 
 Datasets
@@ -0,0 +1,37 @@
+model:
+  name: 'osnet_x1_0'
+  pretrained: True
+
+data:
+  type: 'image'
+  sources: ['market1501']
+  targets: ['market1501']
+  height: 256
+  width: 128
+  combineall: False
+  transforms: ['random_flip']
+  save_dir: 'log/osnet_x1_0_market1501_softmax'
+
+loss:
+  name: 'softmax'
+  softmax:
+    label_smooth: True
+
+train:
+  optim: 'amsgrad'
+  lr: 0.0015
+  max_epoch: 150
+  batch_size: 64
+  fixbase_epoch: 10
+  open_layers: ['classifier']
+  lr_scheduler: 'single_step'
+  stepsize: [60]
+
+test:
+  batch_size: 300
+  dist_metric: 'euclidean'
+  normalize_feature: False
+  evaluate: False
+  eval_freq: 10
+  rerank: False
+  visactmap: False
@@ -0,0 +1,36 @@
+model:
+  name: 'osnet_x1_0'
+  pretrained: True
+
+data:
+  type: 'image'
+  sources: ['market1501']
+  targets: ['market1501']
+  height: 256
+  width: 128
+  combineall: False
+  transforms: ['random_flip']
+  save_dir: 'log/osnet_x1_0_market1501_softmax_cosinelr'
+
+loss:
+  name: 'softmax'
+  softmax:
+    label_smooth: True
+
+train:
+  optim: 'amsgrad'
+  lr: 0.0015
+  max_epoch: 250
+  batch_size: 64
+  fixbase_epoch: 10
+  open_layers: ['classifier']
+  lr_scheduler: 'cosine'
+
+test:
+  batch_size: 300
+  dist_metric: 'euclidean'
+  normalize_feature: False
+  evaluate: False
+  eval_freq: 10
+  rerank: False
+  visactmap: False
@@ -0,0 +1,37 @@
+model:
+  name: 'resnet50_fc512'
+  pretrained: True
+
+data:
+  type: 'image'
+  sources: ['market1501']
+  targets: ['market1501']
+  height: 256
+  width: 128
+  combineall: False
+  transforms: ['random_flip']
+  save_dir: 'log/resnet50_market1501_softmax'
+
+loss:
+  name: 'softmax'
+  softmax:
+    label_smooth: True
+
+train:
+  optim: 'amsgrad'
+  lr: 0.0003
+  max_epoch: 60
+  batch_size: 32
+  fixbase_epoch: 5
+  open_layers: ['classifier']
+  lr_scheduler: 'single_step'
+  stepsize: [20]
+
+test:
+  batch_size: 100
+  dist_metric: 'euclidean'
+  normalize_feature: False
+  evaluate: False
+  eval_freq: 10
+  rerank: False
+  visactmap: False
@@ -0,0 +1,37 @@
+model:
+  name: 'resnet50_fc512'
+  pretrained: True
+
+data:
+  type: 'image'
+  sources: ['market1501']
+  targets: ['market1501']
+  height: 256
+  width: 128
+  combineall: False
+  transforms: ['random_flip']
+  save_dir: 'log/resnet50_fc512_market1501_softmax'
+
+loss:
+  name: 'softmax'
+  softmax:
+    label_smooth: True
+
+train:
+  optim: 'amsgrad'
+  lr: 0.0003
+  max_epoch: 60
+  batch_size: 32
+  fixbase_epoch: 5
+  open_layers: ['fc', 'classifier']
+  lr_scheduler: 'single_step'
+  stepsize: [20]
+
+test:
+  batch_size: 100
+  dist_metric: 'euclidean'
+  normalize_feature: False
+  evaluate: False
+  eval_freq: 10
+  rerank: False
+  visactmap: False
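These four configs override only a subset of keys; everything else falls back to the defaults enumerated in scripts/default_config.py (added further down). Note that yacs rejects YAML keys that are not already defined in the default tree, which is why the defaults file spells out every option. A minimal, self-contained sketch of that merge behavior (hypothetical keys and a temp file, not the project's own config):

.. code-block:: python

    import tempfile
    from yacs.config import CfgNode as CN

    # Hypothetical two-key default tree; yacs raises on YAML keys that are
    # not predefined here, which is why default_config.py lists every key.
    defaults = CN()
    defaults.train = CN()
    defaults.train.max_epoch = 60
    defaults.train.lr = 0.0003

    # A YAML override, written to a temp file so the sketch is runnable.
    with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:
        f.write('train:\n  max_epoch: 150\n')

    defaults.merge_from_file(f.name)
    print(defaults.train.max_epoch, defaults.train.lr)  # 150 0.0003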
@@ -0,0 +1,139 @@
+import sys
+import os
+import os.path as osp
+import warnings
+import time
+
+import torch
+import torch.nn as nn
+
+from default_parser import (
+    init_parser, imagedata_kwargs, videodata_kwargs,
+    optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
+)
+import torchreid
+from torchreid.utils import (
+    Logger, set_random_seed, check_isfile, resume_from_checkpoint,
+    load_pretrained_weights, compute_model_complexity, collect_env_info
+)
+
+
+parser = init_parser()
+args = parser.parse_args()
+
+
+def build_datamanager(args):
+    if args.app == 'image':
+        return torchreid.data.ImageDataManager(**imagedata_kwargs(args))
+    else:
+        return torchreid.data.VideoDataManager(**videodata_kwargs(args))
+
+
+def build_engine(args, datamanager, model, optimizer, scheduler):
+    if args.app == 'image':
+        if args.loss == 'softmax':
+            engine = torchreid.engine.ImageSoftmaxEngine(
+                datamanager,
+                model,
+                optimizer,
+                scheduler=scheduler,
+                use_cpu=args.use_cpu,
+                label_smooth=args.label_smooth
+            )
+        else:
+            engine = torchreid.engine.ImageTripletEngine(
+                datamanager,
+                model,
+                optimizer,
+                margin=args.margin,
+                weight_t=args.weight_t,
+                weight_x=args.weight_x,
+                scheduler=scheduler,
+                use_cpu=args.use_cpu,
+                label_smooth=args.label_smooth
+            )
+
+    else:
+        if args.loss == 'softmax':
+            engine = torchreid.engine.VideoSoftmaxEngine(
+                datamanager,
+                model,
+                optimizer,
+                scheduler=scheduler,
+                use_cpu=args.use_cpu,
+                label_smooth=args.label_smooth,
+                pooling_method=args.pooling_method
+            )
+        else:
+            engine = torchreid.engine.VideoTripletEngine(
+                datamanager,
+                model,
+                optimizer,
+                margin=args.margin,
+                weight_t=args.weight_t,
+                weight_x=args.weight_x,
+                scheduler=scheduler,
+                use_cpu=args.use_cpu,
+                label_smooth=args.label_smooth
+            )
+
+    return engine
+
+
+def main():
+    global args
+
+    set_random_seed(args.seed)
+    if not args.use_avai_gpus:
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
+    use_gpu = torch.cuda.is_available() and not args.use_cpu
+    log_name = 'test.log' if args.evaluate else 'train.log'
+    log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
+    sys.stdout = Logger(osp.join(args.save_dir, log_name))
+    print('** Arguments **')
+    arg_keys = list(args.__dict__.keys())
+    arg_keys.sort()
+    for key in arg_keys:
+        print('{}: {}'.format(key, args.__dict__[key]))
+    print('\n')
+    print('Collecting env info ...')
+    print('** System info **\n{}\n'.format(collect_env_info()))
+    if use_gpu:
+        torch.backends.cudnn.benchmark = True
+    else:
+        warnings.warn('Currently using CPU, however, GPU is highly recommended')
+
+    datamanager = build_datamanager(args)
+
+    print('Building model: {}'.format(args.arch))
+    model = torchreid.models.build_model(
+        name=args.arch,
+        num_classes=datamanager.num_train_pids,
+        loss=args.loss.lower(),
+        pretrained=(not args.no_pretrained),
+        use_gpu=use_gpu
+    )
+    num_params, flops = compute_model_complexity(model, (1, 3, args.height, args.width))
+    print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))
+
+    if args.load_weights and check_isfile(args.load_weights):
+        load_pretrained_weights(model, args.load_weights)
+
+    if use_gpu:
+        model = nn.DataParallel(model).cuda()
+
+    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(args))
+
+    scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
+
+    if args.resume and check_isfile(args.resume):
+        args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
+
+    print('Building {}-engine for {}-reid'.format(args.loss, args.app))
+    engine = build_engine(args, datamanager, model, optimizer, scheduler)
+
+    engine.run(**engine_run_kwargs(args))
+
+
+if __name__ == '__main__':
+    main()
@@ -193,7 +193,7 @@ Note that ``fixbase_epoch`` is counted into ``max_epoch``. In the above example,
 
 Test a trained model
 ----------------------
-You can load a trained model using :code:`torchreid.utils.load_pretrained_weights(model, weight_path)` and set ``test_only=True`` in ``engine.run()``. If you use ``scripts/main.py``, you can do ``--evaluate --load-weights PATH_TO_WEIGHTS``.
+You can load a trained model using :code:`torchreid.utils.load_pretrained_weights(model, weight_path)` and set ``test_only=True`` in ``engine.run()``.
 
 
 Visualize learning curves with tensorboard
@@ -209,22 +209,6 @@ Ranked images can be visualized by setting ``visrank`` to true in ``engine.run()``
     :width: 800px
     :align: center
 
-Example command for ``scripts/main.py`` is
-
-.. code-block:: shell
-
-    python scripts/main.py \
-    --root $DATA \
-    -s market1501 \
-    -t market1501 \
-    -a osnet_x1_0 \
-    --load-weights PATH_TO_WEIGHTS \
-    --evaluate \
-    --visrank \
-    --visrank-topk 15 \
-    --save-dir log/eval-osnet_x1_0 \
-    --gpu-devices 0
-
 
 Visualize activation maps
 --------------------------
@@ -235,21 +219,6 @@ To understand where the CNN focuses on to extract features for ReID
     :align: center
 
-Example command for ``scripts/main.py`` is
-
-.. code-block:: shell
-
-    python scripts/main.py \
-    --root $DATA \
-    -s market1501 \
-    -t market1501 \
-    -a osnet_x1_0 \
-    --load-weights PATH_TO_WEIGHTS \
-    --visactmap \
-    --save-dir log/eval-osnet_x1_0 \
-    --gpu-devices 0
-
 
 .. note::
     In order to visualize activation maps, the CNN needs to output the last convolutional feature maps at eval mode. See ``torchreid/models/osnet.py`` for an example.
@@ -1,12 +0,0 @@
-name: torchreid
-dependencies:
-- python=3.7
-- numpy
-- Cython
-- h5py
-- Pillow
-- six
-- scipy
-- opencv
-- matplotlib
-- future
@@ -8,3 +8,5 @@ opencv-python
 matplotlib
 tb-nightly
 future
+yacs
+gdown
@@ -0,0 +1,201 @@
+import argparse
+from yacs.config import CfgNode as CN
+
+
+def get_default_config():
+    cfg = CN()
+
+    # model
+    cfg.model = CN()
+    cfg.model.name = 'resnet50'
+    cfg.model.pretrained = True # automatically load pretrained model weights if available
+    cfg.model.load_weights = '' # path to model weights
+    cfg.model.resume = '' # path to checkpoint for resume training
+
+    # data
+    cfg.data = CN()
+    cfg.data.type = 'image'
+    cfg.data.root = 'reid-data'
+    cfg.data.sources = ['market1501']
+    cfg.data.targets = ['market1501']
+    cfg.data.workers = 4 # number of data loading workers
+    cfg.data.split_id = 0 # split index
+    cfg.data.height = 256 # image height
+    cfg.data.width = 128 # image width
+    cfg.data.combineall = False # combine train, query and gallery for training
+    cfg.data.transforms = ['random_flip'] # data augmentation
+    cfg.data.norm_mean = [0.485, 0.456, 0.406] # default is imagenet mean
+    cfg.data.norm_std = [0.229, 0.224, 0.225] # default is imagenet std
+    cfg.data.save_dir = 'log' # path to save log
+
+    # specific datasets
+    cfg.market1501 = CN()
+    cfg.market1501.use_500k_distractors = False # add 500k distractors to the gallery set for market1501
+    cfg.cuhk03 = CN()
+    cfg.cuhk03.labeled_images = False # use labeled images, if False, use detected images
+    cfg.cuhk03.classic_split = False # use classic split by Li et al. CVPR14
+    cfg.cuhk03.use_metric_cuhk03 = False # use cuhk03's metric for evaluation
+
+    # sampler
+    cfg.sampler = CN()
+    cfg.sampler.train_sampler = 'RandomSampler'
+    cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler
+
+    # video reid setting
+    cfg.video = CN()
+    cfg.video.seq_len = 15 # number of images to sample in a tracklet
+    cfg.video.sample_method = 'evenly' # how to sample images from a tracklet
+    cfg.video.pooling_method = 'avg' # how to pool features over a tracklet
+
+    # train
+    cfg.train = CN()
+    cfg.train.optim = 'adam'
+    cfg.train.lr = 0.0003
+    cfg.train.weight_decay = 5e-4
+    cfg.train.max_epoch = 60
+    cfg.train.start_epoch = 0
+    cfg.train.batch_size = 32
+    cfg.train.fixbase_epoch = 0 # number of epochs to fix base layers
+    cfg.train.open_layers = ['classifier'] # layers for training while keeping others frozen
+    cfg.train.staged_lr = False # set different lr to different layers
+    cfg.train.new_layers = ['classifier'] # newly added layers with default lr
+    cfg.train.base_lr_mult = 0.1 # learning rate multiplier for base layers
+    cfg.train.lr_scheduler = 'single_step'
+    cfg.train.stepsize = [20] # stepsize to decay learning rate
+    cfg.train.gamma = 0.1 # learning rate decay multiplier
+    cfg.train.print_freq = 20 # print frequency
+    cfg.train.seed = 1 # random seed
+
+    # optimizer
+    cfg.sgd = CN()
+    cfg.sgd.momentum = 0.9 # momentum factor for sgd and rmsprop
+    cfg.sgd.dampening = 0. # dampening for momentum
+    cfg.sgd.nesterov = False # Nesterov momentum
+    cfg.rmsprop = CN()
+    cfg.rmsprop.alpha = 0.99 # smoothing constant
+    cfg.adam = CN()
+    cfg.adam.beta1 = 0.9 # exponential decay rate for first moment
+    cfg.adam.beta2 = 0.999 # exponential decay rate for second moment
+
+    # loss
+    cfg.loss = CN()
+    cfg.loss.name = 'softmax'
+    cfg.loss.softmax = CN()
+    cfg.loss.softmax.label_smooth = True # use label smoothing regularizer
+    cfg.loss.triplet = CN()
+    cfg.loss.triplet.margin = 0.3 # distance margin
+    cfg.loss.triplet.weight_t = 1. # weight to balance hard triplet loss
+    cfg.loss.triplet.weight_x = 0. # weight to balance cross entropy loss
+
+    # test
+    cfg.test = CN()
+    cfg.test.batch_size = 100
+    cfg.test.dist_metric = 'euclidean' # distance metric, ['euclidean', 'cosine']
+    cfg.test.normalize_feature = False # normalize feature vectors before computing distance
+    cfg.test.ranks = [1, 5, 10, 20] # cmc ranks
+    cfg.test.evaluate = False # test only
+    cfg.test.eval_freq = -1 # evaluation frequency (-1 means to only test after training)
+    cfg.test.start_eval = 0 # start to evaluate after a specific epoch
+    cfg.test.rerank = False # use person re-ranking
+    cfg.test.visrank = False # visualize ranked results (only available when cfg.test.evaluate=True)
+    cfg.test.visrank_topk = 10 # top-k ranks to visualize
+    cfg.test.visactmap = False # visualize CNN activation maps
+
+    return cfg
+
+
+def imagedata_kwargs(cfg):
+    return {
+        'root': cfg.data.root,
+        'sources': cfg.data.sources,
+        'targets': cfg.data.targets,
+        'height': cfg.data.height,
+        'width': cfg.data.width,
+        'transforms': cfg.data.transforms,
+        'norm_mean': cfg.data.norm_mean,
+        'norm_std': cfg.data.norm_std,
+        'use_gpu': cfg.use_gpu,
+        'split_id': cfg.data.split_id,
+        'combineall': cfg.data.combineall,
+        'batch_size_train': cfg.train.batch_size,
+        'batch_size_test': cfg.test.batch_size,
+        'workers': cfg.data.workers,
+        'num_instances': cfg.sampler.num_instances,
+        'train_sampler': cfg.sampler.train_sampler,
+        # image
+        'cuhk03_labeled': cfg.cuhk03.labeled_images,
+        'cuhk03_classic_split': cfg.cuhk03.classic_split,
+        'market1501_500k': cfg.market1501.use_500k_distractors,
+    }
+
+
+def videodata_kwargs(cfg):
+    return {
+        'root': cfg.data.root,
+        'sources': cfg.data.sources,
+        'targets': cfg.data.targets,
+        'height': cfg.data.height,
+        'width': cfg.data.width,
+        'transforms': cfg.data.transforms,
+        'norm_mean': cfg.data.norm_mean,
+        'norm_std': cfg.data.norm_std,
+        'use_gpu': cfg.use_gpu,
+        'split_id': cfg.data.split_id,
+        'combineall': cfg.data.combineall,
+        'batch_size_train': cfg.train.batch_size,
+        'batch_size_test': cfg.test.batch_size,
+        'workers': cfg.data.workers,
+        'num_instances': cfg.sampler.num_instances,
+        'train_sampler': cfg.sampler.train_sampler,
+        # video
+        'seq_len': cfg.video.seq_len,
+        'sample_method': cfg.video.sample_method
+    }
+
+
+def optimizer_kwargs(cfg):
+    return {
+        'optim': cfg.train.optim,
+        'lr': cfg.train.lr,
+        'weight_decay': cfg.train.weight_decay,
+        'momentum': cfg.sgd.momentum,
+        'sgd_dampening': cfg.sgd.dampening,
+        'sgd_nesterov': cfg.sgd.nesterov,
+        'rmsprop_alpha': cfg.rmsprop.alpha,
+        'adam_beta1': cfg.adam.beta1,
+        'adam_beta2': cfg.adam.beta2,
+        'staged_lr': cfg.train.staged_lr,
+        'new_layers': cfg.train.new_layers,
+        'base_lr_mult': cfg.train.base_lr_mult
+    }
+
+
+def lr_scheduler_kwargs(cfg):
+    return {
+        'lr_scheduler': cfg.train.lr_scheduler,
+        'stepsize': cfg.train.stepsize,
+        'gamma': cfg.train.gamma,
+        'max_epoch': cfg.train.max_epoch
+    }
+
+
+def engine_run_kwargs(cfg):
+    return {
+        'save_dir': cfg.data.save_dir,
+        'max_epoch': cfg.train.max_epoch,
+        'start_epoch': cfg.train.start_epoch,
+        'fixbase_epoch': cfg.train.fixbase_epoch,
+        'open_layers': cfg.train.open_layers,
+        'start_eval': cfg.test.start_eval,
+        'eval_freq': cfg.test.eval_freq,
+        'test_only': cfg.test.evaluate,
+        'print_freq': cfg.train.print_freq,
+        'dist_metric': cfg.test.dist_metric,
+        'normalize_feature': cfg.test.normalize_feature,
+        'visrank': cfg.test.visrank,
+        'visrank_topk': cfg.test.visrank_topk,
+        'use_metric_cuhk03': cfg.cuhk03.use_metric_cuhk03,
+        'ranks': cfg.test.ranks,
+        'rerank': cfg.test.rerank,
+        'visactmap': cfg.test.visactmap
+    }
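A sketch of how scripts/main.py (diffed below) is meant to consume this module; the override values here are illustrative:

.. code-block:: python

    from default_config import get_default_config, engine_run_kwargs

    cfg = get_default_config()
    cfg.use_gpu = False  # main.py attaches this extra key before freezing
    cfg.merge_from_list(['train.max_epoch', 5, 'test.evaluate', True])
    cfg.freeze()         # later mutations raise, catching typos early

    # The *_kwargs helpers flatten the config tree into plain keyword
    # arguments for torchreid's builder functions and Engine.run().
    print(engine_run_kwargs(cfg)['max_epoch'])  # -> 5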
scripts/main.py | 149 lines changed
@@ -3,12 +3,13 @@ import os
 import os.path as osp
 import warnings
 import time
+import argparse
 
 import torch
 import torch.nn as nn
 
-from default_parser import (
-    init_parser, imagedata_kwargs, videodata_kwargs,
+from default_config import (
+    get_default_config, imagedata_kwargs, videodata_kwargs,
     optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
 )
 import torchreid

@@ -18,121 +19,137 @@ from torchreid.utils import (
 )
 
 
-parser = init_parser()
-args = parser.parse_args()
-
-
-def build_datamanager(args):
-    if args.app == 'image':
-        return torchreid.data.ImageDataManager(**imagedata_kwargs(args))
+def build_datamanager(cfg):
+    if cfg.data.type == 'image':
+        return torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))
     else:
-        return torchreid.data.VideoDataManager(**videodata_kwargs(args))
+        return torchreid.data.VideoDataManager(**videodata_kwargs(cfg))
 
 
-def build_engine(args, datamanager, model, optimizer, scheduler):
-    if args.app == 'image':
-        if args.loss == 'softmax':
+def build_engine(cfg, datamanager, model, optimizer, scheduler):
+    if cfg.data.type == 'image':
+        if cfg.loss.name == 'softmax':
             engine = torchreid.engine.ImageSoftmaxEngine(
                 datamanager,
                 model,
                 optimizer,
                 scheduler=scheduler,
-                use_cpu=args.use_cpu,
-                label_smooth=args.label_smooth
+                use_gpu=cfg.use_gpu,
+                label_smooth=cfg.loss.softmax.label_smooth
             )
         else:
             engine = torchreid.engine.ImageTripletEngine(
                 datamanager,
                 model,
                 optimizer,
-                margin=args.margin,
-                weight_t=args.weight_t,
-                weight_x=args.weight_x,
+                margin=cfg.loss.triplet.margin,
+                weight_t=cfg.loss.triplet.weight_t,
+                weight_x=cfg.loss.triplet.weight_x,
                 scheduler=scheduler,
-                use_cpu=args.use_cpu,
-                label_smooth=args.label_smooth
+                use_gpu=cfg.use_gpu,
+                label_smooth=cfg.loss.softmax.label_smooth
             )
 
     else:
-        if args.loss == 'softmax':
+        if cfg.loss.name == 'softmax':
             engine = torchreid.engine.VideoSoftmaxEngine(
                 datamanager,
                 model,
                 optimizer,
                 scheduler=scheduler,
-                use_cpu=args.use_cpu,
-                label_smooth=args.label_smooth,
-                pooling_method=args.pooling_method
+                use_gpu=cfg.use_gpu,
+                label_smooth=cfg.loss.softmax.label_smooth,
+                pooling_method=cfg.video.pooling_method
             )
         else:
             engine = torchreid.engine.VideoTripletEngine(
                 datamanager,
                 model,
                 optimizer,
-                margin=args.margin,
-                weight_t=args.weight_t,
-                weight_x=args.weight_x,
+                margin=cfg.loss.triplet.margin,
+                weight_t=cfg.loss.triplet.weight_t,
+                weight_x=cfg.loss.triplet.weight_x,
                 scheduler=scheduler,
-                use_cpu=args.use_cpu,
-                label_smooth=args.label_smooth
+                use_gpu=cfg.use_gpu,
+                label_smooth=cfg.loss.softmax.label_smooth
             )
 
     return engine
 
 
-def main():
-    global args
+def reset_config(cfg, args):
+    if args.root:
+        cfg.data.root = args.root
+    if args.sources:
+        cfg.data.sources = args.sources
+    if args.targets:
+        cfg.data.targets = args.targets
+    if args.transforms:
+        cfg.data.transforms = args.transforms
 
-    set_random_seed(args.seed)
-    if not args.use_avai_gpus:
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--config-file', type=str, default='', help='path to config file')
+    parser.add_argument('-s', '--sources', type=str, nargs='+', help='source datasets (delimited by space)')
+    parser.add_argument('-t', '--targets', type=str, nargs='+', help='target datasets (delimited by space)')
+    parser.add_argument('--transforms', type=str, nargs='+', help='data augmentation')
+    parser.add_argument('--root', type=str, default='', help='path to data root')
+    parser.add_argument('--gpu-devices', type=str, default='')
+    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER, help='Modify config options using the command-line')
+    args = parser.parse_args()
+
+    cfg = get_default_config()
+    cfg.use_gpu = torch.cuda.is_available()
+    if args.config_file:
+        cfg.merge_from_file(args.config_file)
+    reset_config(cfg, args)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    set_random_seed(cfg.train.seed)
+
+    if cfg.use_gpu and args.gpu_devices:
+        # if gpu_devices is not specified, all available gpus will be used
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
-    use_gpu = torch.cuda.is_available() and not args.use_cpu
-    log_name = 'test.log' if args.evaluate else 'train.log'
+    log_name = 'test.log' if cfg.test.evaluate else 'train.log'
     log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
-    sys.stdout = Logger(osp.join(args.save_dir, log_name))
-    print('** Arguments **')
-    arg_keys = list(args.__dict__.keys())
-    arg_keys.sort()
-    for key in arg_keys:
-        print('{}: {}'.format(key, args.__dict__[key]))
-    print('\n')
+    sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))
+
+    print('Show configuration\n{}\n'.format(cfg))
     print('Collecting env info ...')
     print('** System info **\n{}\n'.format(collect_env_info()))
-    if use_gpu:
-        torch.backends.cudnn.benchmark = True
-    else:
-        warnings.warn('Currently using CPU, however, GPU is highly recommended')
-
-    datamanager = build_datamanager(args)
-
-    print('Building model: {}'.format(args.arch))
+
+    if cfg.use_gpu:
+        torch.backends.cudnn.benchmark = True
+
+    datamanager = build_datamanager(cfg)
+
+    print('Building model: {}'.format(cfg.model.name))
     model = torchreid.models.build_model(
-        name=args.arch,
+        name=cfg.model.name,
         num_classes=datamanager.num_train_pids,
-        loss=args.loss.lower(),
-        pretrained=(not args.no_pretrained),
-        use_gpu=use_gpu
+        loss=cfg.loss.name,
+        pretrained=cfg.model.pretrained,
+        use_gpu=cfg.use_gpu
     )
-    num_params, flops = compute_model_complexity(model, (1, 3, args.height, args.width))
+    num_params, flops = compute_model_complexity(model, (1, 3, cfg.data.height, cfg.data.width))
     print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))
 
-    if args.load_weights and check_isfile(args.load_weights):
-        load_pretrained_weights(model, args.load_weights)
+    if cfg.model.load_weights and check_isfile(cfg.model.load_weights):
+        load_pretrained_weights(model, cfg.model.load_weights)
 
-    if use_gpu:
+    if cfg.use_gpu:
         model = nn.DataParallel(model).cuda()
 
-    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(args))
-
-    scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))
-
-    if args.resume and check_isfile(args.resume):
-        args.start_epoch = resume_from_checkpoint(args.resume, model, optimizer=optimizer)
-
-    print('Building {}-engine for {}-reid'.format(args.loss, args.app))
-    engine = build_engine(args, datamanager, model, optimizer, scheduler)
-
-    engine.run(**engine_run_kwargs(args))
+    optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
+    scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg))
+
+    if cfg.model.resume and check_isfile(cfg.model.resume):
+        args.start_epoch = resume_from_checkpoint(cfg.model.resume, model, optimizer=optimizer)
+
+    print('Building {}-engine for {}-reid'.format(cfg.loss.name, cfg.data.type))
+    engine = build_engine(cfg, datamanager, model, optimizer, scheduler)
+    engine.run(**engine_run_kwargs(cfg))
 
 
 if __name__ == '__main__':
     main()
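The ``opts`` trick above relies on ``argparse.REMAINDER``: everything after the recognized flags is left untouched for yacs. A tiny self-contained demonstration (standalone, not project code):

.. code-block:: python

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--config-file', default='')
    parser.add_argument('opts', nargs=argparse.REMAINDER)

    # Named flags are parsed as usual; the rest is collected verbatim and can
    # be handed to cfg.merge_from_list as alternating KEY VALUE tokens.
    args = parser.parse_args(['--config-file', 'cfg.yaml', 'test.evaluate', 'True'])
    print(args.config_file)  # cfg.yaml
    print(args.opts)         # ['test.evaluate', 'True']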
@@ -20,11 +20,13 @@ class DataManager(object):
         width (int, optional): target image width. Default is 128.
         transforms (str or list of str, optional): transformations applied to model training.
             Default is 'random_flip'.
-        use_cpu (bool, optional): use cpu. Default is False.
+        norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
+        norm_std (list or None, optional): data std. Default is None (use imagenet std).
+        use_gpu (bool, optional): use gpu. Default is True.
     """
 
     def __init__(self, sources=None, targets=None, height=256, width=128, transforms='random_flip',
-                 use_cpu=False):
+                 norm_mean=None, norm_std=None, use_gpu=False):
         self.sources = sources
         self.targets = targets
         self.height = height

@@ -43,10 +45,11 @@ class DataManager(object):
             self.targets = [self.targets]
 
         self.transform_tr, self.transform_te = build_transforms(
-            self.height, self.width, transforms
+            self.height, self.width, transforms=transforms,
+            norm_mean=norm_mean, norm_std=norm_std
         )
 
-        self.use_gpu = (torch.cuda.is_available() and not use_cpu)
+        self.use_gpu = (torch.cuda.is_available() and use_gpu)
 
     @property
     def num_train_pids(self):

@@ -84,11 +87,14 @@ class ImageDataManager(DataManager):
         width (int, optional): target image width. Default is 128.
         transforms (str or list of str, optional): transformations applied to model training.
             Default is 'random_flip'.
-        use_cpu (bool, optional): use cpu. Default is False.
+        norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
+        norm_std (list or None, optional): data std. Default is None (use imagenet std).
+        use_gpu (bool, optional): use gpu. Default is True.
         split_id (int, optional): split id (*0-based*). Default is 0.
         combineall (bool, optional): combine train, query and gallery in a dataset for
             training. Default is False.
-        batch_size (int, optional): number of images in a batch. Default is 32.
+        batch_size_train (int, optional): number of images in a training batch. Default is 32.
+        batch_size_test (int, optional): number of images in a test batch. Default is 32.
         workers (int, optional): number of workers. Default is 4.
         num_instances (int, optional): number of instances per identity in a batch.
             Default is 4.

@@ -107,18 +113,20 @@ class ImageDataManager(DataManager):
             sources='market1501',
             height=256,
             width=128,
-            batch_size=32
+            batch_size_train=32,
+            batch_size_test=100
         )
     """
     data_type = 'image'
 
     def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
-                 use_cpu=False, split_id=0, combineall=False,
-                 batch_size=32, workers=4, num_instances=4, train_sampler='',
+                 norm_mean=None, norm_std=None, use_gpu=True, split_id=0, combineall=False,
+                 batch_size_train=32, batch_size_test=32, workers=4, num_instances=4, train_sampler='',
                  cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False):
 
         super(ImageDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width,
-                                               transforms=transforms, use_cpu=use_cpu)
+                                               transforms=transforms, norm_mean=norm_mean, norm_std=norm_std,
+                                               use_gpu=use_gpu)
 
         print('=> Loading train (source) dataset')
         trainset = []

@@ -142,14 +150,14 @@ class ImageDataManager(DataManager):
 
         train_sampler = build_train_sampler(
             trainset.train, train_sampler,
-            batch_size=batch_size,
+            batch_size=batch_size_train,
            num_instances=num_instances
         )
 
         self.trainloader = torch.utils.data.DataLoader(
             trainset,
             sampler=train_sampler,
-            batch_size=batch_size,
+            batch_size=batch_size_train,
             shuffle=False,
             num_workers=workers,
             pin_memory=self.use_gpu,

@@ -175,7 +183,7 @@ class ImageDataManager(DataManager):
             )
             self.testloader[name]['query'] = torch.utils.data.DataLoader(
                 queryset,
-                batch_size=batch_size,
+                batch_size=batch_size_test,
                 shuffle=False,
                 num_workers=workers,
                 pin_memory=self.use_gpu,

@@ -197,7 +205,7 @@ class ImageDataManager(DataManager):
             )
             self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                 galleryset,
-                batch_size=batch_size,
+                batch_size=batch_size_test,
                 shuffle=False,
                 num_workers=workers,
                 pin_memory=self.use_gpu,

@@ -231,19 +239,22 @@ class VideoDataManager(DataManager):
         width (int, optional): target image width. Default is 128.
         transforms (str or list of str, optional): transformations applied to model training.
             Default is 'random_flip'.
-        use_cpu (bool, optional): use cpu. Default is False.
+        norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
+        norm_std (list or None, optional): data std. Default is None (use imagenet std).
+        use_gpu (bool, optional): use gpu. Default is True.
         split_id (int, optional): split id (*0-based*). Default is 0.
         combineall (bool, optional): combine train, query and gallery in a dataset for
             training. Default is False.
-        batch_size (int, optional): number of *tracklets* in a batch. Default is 3.
+        batch_size_train (int, optional): number of tracklets in a training batch. Default is 3.
+        batch_size_test (int, optional): number of tracklets in a test batch. Default is 3.
         workers (int, optional): number of workers. Default is 4.
         num_instances (int, optional): number of instances per identity in a batch.
             Default is 4.
         train_sampler (str, optional): sampler. Default is empty (``RandomSampler``).
         seq_len (int, optional): how many images to sample in a tracklet. Default is 15.
         sample_method (str, optional): how to sample images in a tracklet. Default is "evenly".
-            Choices are ["evenly", "random", "all"]. "evenly" and "random" sample ``seq_len``
-            images in a tracklet while "all" samples all images in a tracklet, thus ``batch_size``
+            Choices are ["evenly", "random", "all"]. "evenly" and "random" will sample ``seq_len``
+            images in a tracklet while "all" samples all images in a tracklet, where the batch size
             needs to be set to 1.
 
     Examples::

@@ -253,7 +264,8 @@ class VideoDataManager(DataManager):
             sources='mars',
             height=256,
             width=128,
-            batch_size=3,
+            batch_size_train=3,
+            batch_size_test=3,
             seq_len=15,
             sample_method='evenly'
         )

@@ -267,12 +279,13 @@ class VideoDataManager(DataManager):
     data_type = 'video'
 
     def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
-                 use_cpu=False, split_id=0, combineall=False,
-                 batch_size=3, workers=4, num_instances=4, train_sampler=None,
+                 norm_mean=None, norm_std=None, use_gpu=True, split_id=0, combineall=False,
+                 batch_size_train=3, batch_size_test=3, workers=4, num_instances=4, train_sampler=None,
                  seq_len=15, sample_method='evenly'):
 
         super(VideoDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width,
-                                               transforms=transforms, use_cpu=use_cpu)
+                                               transforms=transforms, norm_mean=norm_mean, norm_std=norm_std,
+                                               use_gpu=use_gpu)
 
         print('=> Loading train (source) dataset')
         trainset = []

@@ -295,14 +308,14 @@ class VideoDataManager(DataManager):
 
         train_sampler = build_train_sampler(
             trainset.train, train_sampler,
-            batch_size=batch_size,
+            batch_size=batch_size_train,
             num_instances=num_instances
         )
 
         self.trainloader = torch.utils.data.DataLoader(
             trainset,
             sampler=train_sampler,
-            batch_size=batch_size,
+            batch_size=batch_size_train,
             shuffle=False,
             num_workers=workers,
             pin_memory=self.use_gpu,

@@ -327,7 +340,7 @@ class VideoDataManager(DataManager):
             )
             self.testloader[name]['query'] = torch.utils.data.DataLoader(
                 queryset,
-                batch_size=batch_size,
+                batch_size=batch_size_test,
                 shuffle=False,
                 num_workers=workers,
                 pin_memory=self.use_gpu,

@@ -348,7 +361,7 @@ class VideoDataManager(DataManager):
             )
             self.testloader[name]['gallery'] = torch.utils.data.DataLoader(
                 galleryset,
-                batch_size=batch_size,
+                batch_size=batch_size_test,
                 shuffle=False,
                 num_workers=workers,
                 pin_memory=self.use_gpu,
@@ -140,8 +140,8 @@ def build_transforms(height, width, transforms='random_flip', norm_mean=[0.485,
         width (int): target image width.
         transforms (str or list of str, optional): transformations applied to model training.
             Default is 'random_flip'.
-        norm_mean (list): normalization mean values. Default is ImageNet means.
-        norm_std (list): normalization standard deviation values. Default is
+        norm_mean (list or None, optional): normalization mean values. Default is ImageNet means.
+        norm_std (list or None, optional): normalization standard deviation values. Default is
             ImageNet standard deviation values.
     """
     if transforms is None:

@@ -156,6 +156,9 @@ def build_transforms(height, width, transforms='random_flip', norm_mean=[0.485,
     if len(transforms) > 0:
         transforms = [t.lower() for t in transforms]
 
+    if norm_mean is None or norm_std is None:
+        norm_mean = [0.485, 0.456, 0.406] # imagenet mean
+        norm_std = [0.229, 0.224, 0.225] # imagenet std
     normalize = Normalize(mean=norm_mean, std=norm_std)
 
     print('Building train transforms ...')
@@ -35,15 +35,15 @@ class Engine(object):
         model (nn.Module): model instance.
         optimizer (Optimizer): an Optimizer.
         scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
-        use_cpu (bool, optional): use cpu. Default is False.
+        use_gpu (bool, optional): use gpu. Default is True.
     """
 
-    def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_cpu=False):
+    def __init__(self, datamanager, model, optimizer=None, scheduler=None, use_gpu=True):
         self.datamanager = datamanager
         self.model = model
         self.optimizer = optimizer
         self.scheduler = scheduler
-        self.use_gpu = (torch.cuda.is_available() and not use_cpu)
+        self.use_gpu = (torch.cuda.is_available() and use_gpu)
         self.writer = None
 
         # check attributes
@@ -23,7 +23,7 @@ class ImageSoftmaxEngine(engine.Engine):
         model (nn.Module): model instance.
         optimizer (Optimizer): an Optimizer.
         scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
-        use_cpu (bool, optional): use cpu. Default is False.
+        use_gpu (bool, optional): use gpu. Default is True.
         label_smooth (bool, optional): use label smoothing regularizer. Default is True.
 
     Examples::

@@ -62,9 +62,9 @@ class ImageSoftmaxEngine(engine.Engine):
         )
     """
 
-    def __init__(self, datamanager, model, optimizer, scheduler=None, use_cpu=False,
+    def __init__(self, datamanager, model, optimizer, scheduler=None, use_gpu=True,
                  label_smooth=True):
-        super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu)
+        super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu)
 
         self.criterion = CrossEntropyLoss(
             num_classes=self.datamanager.num_train_pids,
@@ -26,7 +26,7 @@ class ImageTripletEngine(engine.Engine):
         weight_t (float, optional): weight for triplet loss. Default is 1.
         weight_x (float, optional): weight for softmax loss. Default is 1.
         scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
-        use_cpu (bool, optional): use cpu. Default is False.
+        use_gpu (bool, optional): use gpu. Default is True.
         label_smooth (bool, optional): use label smoothing regularizer. Default is True.
 
     Examples::

@@ -69,9 +69,9 @@ class ImageTripletEngine(engine.Engine):
     """
 
     def __init__(self, datamanager, model, optimizer, margin=0.3,
-                 weight_t=1, weight_x=1, scheduler=None, use_cpu=False,
+                 weight_t=1, weight_x=1, scheduler=None, use_gpu=True,
                  label_smooth=True):
-        super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_cpu)
+        super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu)
 
         self.weight_t = weight_t
         self.weight_x = weight_x
@@ -20,7 +20,7 @@ class VideoSoftmaxEngine(ImageSoftmaxEngine):
         model (nn.Module): model instance.
         optimizer (Optimizer): an Optimizer.
         scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
-        use_cpu (bool, optional): use cpu. Default is False.
+        use_gpu (bool, optional): use gpu. Default is True.
         label_smooth (bool, optional): use label smoothing regularizer. Default is True.
         pooling_method (str, optional): how to pool features for a tracklet.
             Default is "avg" (average). Choices are ["avg", "max"].

@@ -65,9 +65,9 @@ class VideoSoftmaxEngine(ImageSoftmaxEngine):
     """
 
     def __init__(self, datamanager, model, optimizer, scheduler=None,
-                 use_cpu=False, label_smooth=True, pooling_method='avg'):
+                 use_gpu=True, label_smooth=True, pooling_method='avg'):
         super(VideoSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler=scheduler,
-                                                 use_cpu=use_cpu, label_smooth=label_smooth)
+                                                 use_gpu=use_gpu, label_smooth=label_smooth)
         self.pooling_method = pooling_method
 
     def _parse_data_for_train(self, data):
@@ -24,7 +24,7 @@ class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
         weight_t (float, optional): weight for triplet loss. Default is 1.
         weight_x (float, optional): weight for softmax loss. Default is 1.
         scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.
-        use_cpu (bool, optional): use cpu. Default is False.
+        use_gpu (bool, optional): use gpu. Default is True.
         label_smooth (bool, optional): use label smoothing regularizer. Default is True.
         pooling_method (str, optional): how to pool features for a tracklet.
             Default is "avg" (average). Choices are ["avg", "max"].

@@ -73,10 +73,10 @@ class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine):
     """
 
     def __init__(self, datamanager, model, optimizer, margin=0.3,
-                 weight_t=1, weight_x=1, scheduler=None, use_cpu=False,
+                 weight_t=1, weight_x=1, scheduler=None, use_gpu=False,
                  label_smooth=True, pooling_method='avg'):
         super(VideoTripletEngine, self).__init__(datamanager, model, optimizer, margin=margin,
                                                  weight_t=weight_t, weight_x=weight_x,
-                                                 scheduler=scheduler, use_cpu=use_cpu,
+                                                 scheduler=scheduler, use_gpu=use_gpu,
                                                  label_smooth=label_smooth)
         self.pooling_method = pooling_method
@@ -9,6 +9,15 @@ from torch.nn import functional as F
 import torchvision
 
 
+pretrained_urls = {
+    'osnet_x1_0': 'https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY',
+    'osnet_x0_75': 'https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq',
+    'osnet_x0_5': 'https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i',
+    'osnet_x0_25': 'https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs',
+    'osnet_ibn_x1_0': 'https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l'
+}
+
+
 ##########
 # Basic layers
 ##########

@@ -311,35 +320,116 @@ class OSNet(nn.Module):
             raise KeyError("Unsupported loss: {}".format(self.loss))
 
 
+def init_pretrained_weights(model, key=''):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+    from collections import OrderedDict
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(ENV_TORCH_HOME,
+                      os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+    filename = key + '_imagenet.pth'
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(pretrained_urls[key], cached_file, quiet=False)
+
+    state_dict = torch.load(cached_file)
+    model_dict = model.state_dict()
+    new_state_dict = OrderedDict()
+    matched_layers, discarded_layers = [], []
+
+    for k, v in state_dict.items():
+        if k.startswith('module.'):
+            k = k[7:] # discard module.
+
+        if k in model_dict and model_dict[k].size() == v.size():
+            new_state_dict[k] = v
+            matched_layers.append(k)
+        else:
+            discarded_layers.append(k)
+
+    model_dict.update(new_state_dict)
+    model.load_state_dict(model_dict)
+
+    if len(matched_layers) == 0:
+        warnings.warn(
+            'The pretrained weights from "{}" cannot be loaded, '
+            'please check the key names manually '
+            '(** ignored and continue **)'.format(cached_file))
+    else:
+        print('Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file))
+        if len(discarded_layers) > 0:
+            print('** The following layers are discarded '
+                  'due to unmatched keys or layer size: {}'.format(discarded_layers))
+
+
 ##########
 # Instantiation
 ##########
-def osnet_x1_0(num_classes=1000, loss='softmax', **kwargs):
+def osnet_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
     # standard size (width x1.0)
-    return OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
-                 channels=[64, 256, 384, 512], loss=loss, **kwargs)
+    model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
+                  channels=[64, 256, 384, 512], loss=loss, **kwargs)
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x1_0')
+    return model
 
 
-def osnet_x0_75(num_classes=1000, loss='softmax', **kwargs):
+def osnet_x0_75(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
     # medium size (width x0.75)
-    return OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
-                 channels=[48, 192, 288, 384], loss=loss, **kwargs)
+    model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
+                  channels=[48, 192, 288, 384], loss=loss, **kwargs)
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x0_75')
+    return model
 
 
-def osnet_x0_5(num_classes=1000, loss='softmax', **kwargs):
+def osnet_x0_5(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
     # tiny size (width x0.5)
-    return OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
-                 channels=[32, 128, 192, 256], loss=loss, **kwargs)
+    model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
+                  channels=[32, 128, 192, 256], loss=loss, **kwargs)
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x0_5')
+    return model
 
 
-def osnet_x0_25(num_classes=1000, loss='softmax', **kwargs):
+def osnet_x0_25(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
     # very tiny size (width x0.25)
-    return OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
-                 channels=[16, 64, 96, 128], loss=loss, **kwargs)
+    model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
+                  channels=[16, 64, 96, 128], loss=loss, **kwargs)
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_x0_25')
+    return model
 
 
-def osnet_ibn_x1_0(num_classes=1000, loss='softmax', **kwargs):
+def osnet_ibn_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
     # standard size (width x1.0) + IBN layer
     # Ref: Pan et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net. ECCV, 2018.
-    return OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
-                 channels=[64, 256, 384, 512], loss=loss, IN=True, **kwargs)
+    model = OSNet(num_classes, blocks=[OSBlock, OSBlock, OSBlock], layers=[2, 2, 2],
+                  channels=[64, 256, 384, 512], loss=loss, IN=True, **kwargs)
+    if pretrained:
+        init_pretrained_weights(model, key='osnet_ibn_x1_0')
+    return model
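With the new ``pretrained`` flag, constructing an OSNet backbone now pulls the ImageNet weights automatically. A quick usage sketch (751 is Market1501's number of training identities):

.. code-block:: python

    import torchreid

    # The ImageNet checkpoint's classifier has a different shape from a
    # 751-way re-id head, so init_pretrained_weights loads all matching
    # layers and reports the classifier among the discarded ones, leaving
    # it randomly initialized.
    model = torchreid.models.build_model(
        name='osnet_x1_0',
        num_classes=751,
        loss='softmax',
        pretrained=True
    )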
@@ -279,4 +279,4 @@ def load_pretrained_weights(model, weight_path):
         print('Successfully loaded pretrained weights from "{}"'.format(weight_path))
         if len(discarded_layers) > 0:
             print('** The following layers are discarded '
-                'due to unmatched keys or layer size: {}'.format(discarded_layers))
+                  'due to unmatched keys or layer size: {}'.format(discarded_layers))