From 69e12d989d427a09bb2ec0ee3e90ca1a2998d23c Mon Sep 17 00:00:00 2001 From: sherlock Date: Thu, 10 Jan 2019 18:39:31 +0800 Subject: [PATCH] Update first stable version v1.0 --- README.md | 58 ++-- config/__init__.py | 7 + config/defaults.py | 101 +++++++ configs/market_softmax.yml | 39 --- configs/market_softmax_triplet.yml | 41 --- configs/market_triplet.yml | 40 --- configs/softmax.yml | 43 +++ configs/softmax_triplet.yml | 45 +++ core/__init__.py | 11 - core/config.py | 79 ------ core/loader.py | 126 --------- core/solver.py | 187 ------------- data/__init__.py | 7 + data/build.py | 44 +++ data/collate_batch.py | 18 ++ data/datasets/__init__.py | 25 ++ data/datasets/bases.py | 95 +++++++ data/datasets/cuhk03.py | 259 ++++++++++++++++++ data/datasets/dataset_loader.py | 45 +++ data/datasets/dukemtmcreid.py | 106 +++++++ data/datasets/eval_reid.py | 63 +++++ .../datasets/market1501.py | 70 ++--- data/samplers/__init__.py | 7 + data/samplers/triplet_sampler.py | 73 +++++ data/transforms/__init__.py | 7 + data/transforms/build.py | 31 +++ .../transforms/transforms.py | 49 +--- engine/inference.py | 64 +++++ engine/trainer.py | 150 ++++++++++ layers/__init__.py | 28 ++ utils/loss.py => layers/triplet_loss.py | 40 +-- modeling/__init__.py | 13 + modeling/backbones/__init__.py | 6 + {network => modeling/backbones}/resnet.py | 21 +- {network => modeling}/baseline.py | 24 +- network/__init__.py | 13 - scripts/test.sh | 5 - scripts/train_softmax.sh | 8 - scripts/train_softmax_triplet.sh | 8 - scripts/train_triplet.sh | 8 - solver/__init__.py | 8 + solver/build.py | 25 ++ solver/lr_scheduler.py | 56 ++++ tests/__init__.py | 5 + tests/lr_scheduler_test.py | 26 ++ tools/__init__.py | 6 - tools/test.py | 85 +++--- tools/train.py | 128 ++++----- utils/__init__.py | 9 +- utils/iotools.py | 39 +++ utils/logger.py | 30 ++ utils/lr_scheduler.py | 65 ----- utils/meters.py | 54 ---- utils/reid_metric.py | 48 ++++ utils/serialization.py | 35 --- 55 files changed, 1641 insertions(+), 1042 deletions(-) create mode 100644 config/__init__.py create mode 100644 config/defaults.py delete mode 100644 configs/market_softmax.yml delete mode 100644 configs/market_softmax_triplet.yml delete mode 100644 configs/market_triplet.yml create mode 100644 configs/softmax.yml create mode 100644 configs/softmax_triplet.yml delete mode 100644 core/__init__.py delete mode 100644 core/config.py delete mode 100755 core/loader.py delete mode 100644 core/solver.py create mode 100644 data/__init__.py create mode 100644 data/build.py create mode 100644 data/collate_batch.py create mode 100644 data/datasets/__init__.py create mode 100644 data/datasets/bases.py create mode 100644 data/datasets/cuhk03.py create mode 100644 data/datasets/dataset_loader.py create mode 100644 data/datasets/dukemtmcreid.py create mode 100644 data/datasets/eval_reid.py rename core/data_manager.py => data/datasets/market1501.py (52%) mode change 100755 => 100644 create mode 100644 data/samplers/__init__.py create mode 100644 data/samplers/triplet_sampler.py create mode 100644 data/transforms/__init__.py create mode 100644 data/transforms/build.py rename utils/augmenter.py => data/transforms/transforms.py (53%) create mode 100644 engine/inference.py create mode 100644 engine/trainer.py create mode 100644 layers/__init__.py rename utils/loss.py => layers/triplet_loss.py (74%) create mode 100644 modeling/__init__.py create mode 100644 modeling/backbones/__init__.py rename {network => modeling/backbones}/resnet.py (89%) rename {network => modeling}/baseline.py (77%) delete mode 100644 network/__init__.py delete mode 100644 scripts/test.sh delete mode 100644 scripts/train_softmax.sh delete mode 100644 scripts/train_softmax_triplet.sh delete mode 100644 scripts/train_triplet.sh create mode 100644 solver/__init__.py create mode 100644 solver/build.py create mode 100644 solver/lr_scheduler.py create mode 100644 tests/__init__.py create mode 100644 tests/lr_scheduler_test.py create mode 100644 utils/iotools.py create mode 100644 utils/logger.py delete mode 100644 utils/lr_scheduler.py delete mode 100644 utils/meters.py create mode 100644 utils/reid_metric.py delete mode 100644 utils/serialization.py diff --git a/README.md b/README.md index b0bc885..ea9479c 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,23 @@ # ReID_baseline -Baseline model (with bottleneck) for person ReID (using softmax and triplet loss). This is PyTorch version, [mxnet version](https://github.com/L1aoXingyu/reid_baseline_gluon) has a better result and more SOTA methods. +Baseline model (with bottleneck) for person ReID (using softmax and triplet loss). We support -- multi-GPU training -- easy dataset preparation -- end-to-end training and evaluation +- [x] easy dataset preparation +- [x] end-to-end training and evaluation +- [x] high modular management ## Get Started +The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself. + 1. `cd` to folder where you want to download this repo 2. Run `git clone https://github.com/L1aoXingyu/reid_baseline.git` 3. Install dependencies: - - [pytorch 0.4](https://pytorch.org/) + - [pytorch 1.0](https://pytorch.org/) - torchvision - - tensorflow (for tensorboard) - - [tensorboardX](https://github.com/lanpa/tensorboardX) + - [ignite](https://github.com/pytorch/ignite) + - [yacs](https://github.com/rbgirshick/yacs) 4. Prepare dataset - + Create a directory to store reid datasets under this repo via ```bash cd reid_baseline @@ -23,39 +25,43 @@ We support ``` 1. Download dataset to `data/` from http://www.liangzheng.org/Project/project_reid.html 2. Extract dataset and rename to `market1501`. The data structure would like: - ``` - market1501/ - bounding_box_test/ - bounding_box_train/ + ```bash + data + market1501 + bounding_box_test/ + bounding_box_train/ ``` 5. Prepare pretrained model if you don't have ```python from torchvision import models models.resnet50(pretrained=True) ``` - Then it will automatically download model in `~.torch/models/`, you should set this path in `config.py` + Then it will automatically download model in `~/.torch/models/`, you should set this path in `config/defaults.py` for all training or set in every single training config file in `configs/`. ## Train -You can run +Most of the configuration files that we provide, you can run this command for training ```bash -bash scripts/train_triplet_softmax.sh +python3 tools/train.py --config_file='configs/market1501_softmax_bs64.yml' +``` + +You can also modify your cfg parameters as follow +```bash +python3 tools/train.py --config_file='configs/market1501_softmax_bs64.yml' INPUT.SIZE_TRAIN '(256, 128)' INPUT.SIZE_TEST '(256, 128)' ``` -in `reid_baseline` folder if you want to train with softmax and triplet loss. You can find others train scripts in `scripts`. ## Results **network architecture** +
+
+ +| cfg | market1501 | cuhk03 | dukemtmc | +| --- | -- | -- | -- | +| softmax, size=(384, 128), batch_size=64 | 92.5 (79.4) | 60.4 (56.1) | 84.6 (68.1) | +| softmax, size=(256, 128), batch_size=64 | 92.0 (80.4) | 60.5 (55.5) | 84.1(68.4) | +| softmax_triplet, size=(384, 128), batch_size=128(32 id x 4 imgs) | 93.2 (82.5) | - | 86.4 (73.1) +| softmax_triplet, size=(256, 128), batch_size=128(32 id x 4 imgs) | 93.8 (83.2) | 65.9 (61.4) | - -| config | Market1501 | -| --- | -- | -| bs(32) size(384,128) softmax | 92.2 (78.5) | -| bs(64) size(384,128) softmax | 92.5 (79.6) | -| bs(32) size(256,128) softmax | 92.0 (78.4) | -| bs(64) size(256,128) softmax | 91.7 (78.3) | -| bs(128) size(256,128) softmax | 91.2 (77.4) | -| triplet(p=32,k=4) size(256,128) | 88.3 (73.8) | -| triplet(p=16,k=4)+softmax size(384,128) | 93.1 (82.0) | -| triplet(p=24,k=4)+softmax size(384,128) | 91.7 (79.0) | diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..5d60189 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,7 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +from .defaults import _C as cfg diff --git a/config/defaults.py b/config/defaults.py new file mode 100644 index 0000000..12c5ae5 --- /dev/null +++ b/config/defaults.py @@ -0,0 +1,101 @@ +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, +# or _TEST for a test-specific parameter. +# For example, the number of images during training will be +# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be +# IMAGES_PER_BATCH_TEST + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +_C.MODEL = CN() +_C.MODEL.DEVICE = "cuda" +_C.MODEL.NAME = 'resnet50' +_C.MODEL.LAST_STRIDE = 1 +_C.MODEL.PRETRAIN_PATH = '' +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# Size of the image during training +_C.INPUT.SIZE_TRAIN = [384, 128] +# Size of the image during test +_C.INPUT.SIZE_TEST = [384, 128] +# Random probability for image horizontal flip +_C.INPUT.PROB = 0.5 +# Values to be used for image normalization +_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] +# Values to be used for image normalization +_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] +# Value of padding size +_C.INPUT.PADDING = 10 + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training, as present in paths_catalog.py +_C.DATASETS.NAMES = ('market1501') + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 8 +# Sampler for data loading +_C.DATALOADER.SAMPLER = 'softmax' +# Number of instance for one batch +_C.DATALOADER.NUM_INSTANCE = 16 + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() +_C.SOLVER.OPTIMIZER_NAME = "Adam" + +_C.SOLVER.MAX_EPOCHS = 50 + +_C.SOLVER.BASE_LR = 3e-4 +_C.SOLVER.BIAS_LR_FACTOR = 2 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.MARGIN = 0.3 + +_C.SOLVER.WEIGHT_DECAY = 0.0005 +_C.SOLVER.WEIGHT_DECAY_BIAS = 0. + +_C.SOLVER.GAMMA = 0.1 +_C.SOLVER.STEPS = (30, 55) + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 3 +_C.SOLVER.WARMUP_ITERS = 500 +_C.SOLVER.WARMUP_METHOD = "linear" + +_C.SOLVER.CHECKPOINT_PERIOD = 50 +_C.SOLVER.LOG_PERIOD = 100 +_C.SOLVER.EVAL_PERIOD = 50 +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.SOLVER.IMS_PER_BATCH = 64 + +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.TEST = CN() +_C.TEST.IMS_PER_BATCH = 128 +_C.TEST.WEIGHT = "" + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +_C.OUTPUT_DIR = "" diff --git a/configs/market_softmax.yml b/configs/market_softmax.yml deleted file mode 100644 index 6eed905..0000000 --- a/configs/market_softmax.yml +++ /dev/null @@ -1,39 +0,0 @@ -# configuration for training market1501 - -dataset: - name: market1501 - -aug: - resize_size: [384, 128] - random_mirror: True - pad: 10 - random_crop: True - random_erasing: True - -train: - optimizer: 'Adam' - lr: 0.00035 - num_epochs: 80 - batch_size: 32 - sampler: 'softmax' - wd: 0.0005 - step: [30, 55] - factor: 0.1 - warmup_epoch: 5 - warmup_begin_lr: 0.0000035 - loss_fn: 'softmax' - -test: - batch_size: 128 - -network: - name: 'Baseline' - last_stride: 1 - gpus: '0' - -misc: - eval_step: 20 - save_step: 20 - log_interval: 100 - - diff --git a/configs/market_softmax_triplet.yml b/configs/market_softmax_triplet.yml deleted file mode 100644 index 5cffbed..0000000 --- a/configs/market_softmax_triplet.yml +++ /dev/null @@ -1,41 +0,0 @@ -# configuration for training market1501 - -dataset: - name: market1501 - -aug: - resize_size: [384, 128] - random_mirror: True - pad: 10 - random_crop: True - random_erasing: True - -train: - optimizer: 'Adam' - lr: 0.00035 - num_epochs: 400 - p_size: 16 - k_size: 4 - sampler: 'triplet' - wd: 0.0005 - step: [80, 180, 300] - factor: 0.1 - warmup_epoch: 20 - warmup_begin_lr: 0.0000035 - loss_fn: 'softmax_triplet' - - -test: - batch_size: 128 - -network: - name: 'Baseline' - last_stride: 1 - gpus: '1' - -misc: - eval_step: 50 - save_step: 50 - log_interval: 20 - - diff --git a/configs/market_triplet.yml b/configs/market_triplet.yml deleted file mode 100644 index 31a7931..0000000 --- a/configs/market_triplet.yml +++ /dev/null @@ -1,40 +0,0 @@ -# configuration for training market1501 - -dataset: - name: market1501 - -aug: - resize_size: [384, 128] - random_mirror: True - pad: 10 - random_crop: True - -train: - optimizer: 'Adam' - lr: 0.00035 - num_epochs: 400 - p_size: 32 - k_size: 4 - sampler: 'triplet' - wd: 0.0005 - step: [80, 180, 300] - factor: 0.1 - warmup_epoch: 20 - warmup_begin_lr: 0.0000035 - loss_fn: 'triplet' - - -test: - batch_size: 128 - -network: - name: 'Baseline' - last_stride: 1 - gpus: '1' - -misc: - eval_step: 50 - save_step: 50 - log_interval: 20 - - diff --git a/configs/softmax.yml b/configs/softmax.yml new file mode 100644 index 0000000..5e0ded7 --- /dev/null +++ b/configs/softmax.yml @@ -0,0 +1,43 @@ +MODEL: + PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth' + + +INPUT: + SIZE_TRAIN: [384, 128] + SIZE_TEST: [384, 128] + PROB: 0.5 # random horizontal flip + PADDING: 10 + +DATASETS: + NAMES: ('market1501') + +DATALOADER: + SAMPLER: 'softmax' + NUM_WORKERS: 8 + +SOLVER: + OPTIMIZER_NAME: 'Adam' + MAX_EPOCHS: 120 + BASE_LR: 0.00035 + BIAS_LR_FACTOR: 1 + WEIGHT_DECAY: 0.0005 + WEIGHT_DECAY_BIAS: 0.0005 + IMS_PER_BATCH: 64 + + STEPS: [30, 55] + GAMMA: 0.1 + + WARMUP_FACTOR: 0.01 + WARMUP_ITERS: 5 + WARMUP_METHOD: 'linear' + + CHECKPOINT_PERIOD: 20 + LOG_PERIOD: 100 + EVAL_PERIOD: 20 + +TEST: + IMS_PER_BATCH: 256 + +OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_bs64_384x128" + + diff --git a/configs/softmax_triplet.yml b/configs/softmax_triplet.yml new file mode 100644 index 0000000..b999844 --- /dev/null +++ b/configs/softmax_triplet.yml @@ -0,0 +1,45 @@ +MODEL: + PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth' + + +INPUT: + SIZE_TRAIN: [384, 128] + SIZE_TEST: [384, 128] + PROB: 0.5 # random horizontal flip + PADDING: 10 + +DATASETS: + NAMES: ('market1501') + +DATALOADER: + SAMPLER: 'softmax_triplet' + NUM_INSTANCE: 4 + NUM_WORKERS: 8 + +SOLVER: + OPTIMIZER_NAME: 'Adam' + MAX_EPOCHS: 120 + BASE_LR: 0.00035 + BIAS_LR_FACTOR: 1 + WEIGHT_DECAY: 0.0005 + WEIGHT_DECAY_BIAS: 0.0005 + IMS_PER_BATCH: 64 + + STEPS: [40, 70] + GAMMA: 0.1 + + WARMUP_FACTOR: 0.01 + WARMUP_ITERS: 10 + WARMUP_METHOD: 'linear' + + CHECKPOINT_PERIOD: 40 + LOG_PERIOD: 100 + EVAL_PERIOD: 40 + +TEST: + IMS_PER_BATCH: 256 + WEIGHT: "path" + +OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_triplet_bs128_384x128" + + diff --git a/core/__init__.py b/core/__init__.py deleted file mode 100644 index 0572b5a..0000000 --- a/core/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# encoding: utf-8 -""" -@author: sherlock -@contact: sherlockliao01@gmail.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - diff --git a/core/config.py b/core/config.py deleted file mode 100644 index 70c8266..0000000 --- a/core/config.py +++ /dev/null @@ -1,79 +0,0 @@ -# encoding: utf-8 -""" -@author: sherlock -@contact: sherlockliao01@gmail.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import yaml -from easydict import EasyDict as edict - -__C = edict() -opt = __C -__C.seed = 0 - -__C.dataset = edict() -__C.dataset.name = 'market1501' -__C.dataset.num_classes = 751 - -__C.aug = edict() -__C.aug.resize_size = [256, 128] -__C.aug.color_jitter = False -__C.aug.random_erasing = False -__C.aug.random_mirror = True -__C.aug.pad = 10 -__C.aug.random_crop = True - -__C.train = edict() -__C.train.optimizer = 'Adam' -__C.train.lr = 3e-4 -__C.train.wd = 5e-4 -__C.train.momentum = 0.9 -__C.train.step = [80, 180, 300] -__C.train.warmup_epoch = 20 -__C.train.warmup_begin_lr = 3e-6 -__C.train.factor = 0.1 -__C.train.margin = 0.3 -__C.train.num_epochs = 400 -__C.train.sampler = 'softmax' -__C.train.p_size = 32 # number of person in a single gpu -__C.train.k_size = 4 # number of images per person -__C.train.batch_size = 128 -__C.train.loss_fn = 'softmax' # softmax, triplet, softmax_triplet -__C.train.triplet_normalize = False - -__C.test = edict() -__C.test.batch_size = 128 -__C.test.load_path = '/mnt/truenas/scratch/xingyu.liao/DATA/mx-ckpt' - -__C.network = edict() -__C.network.depth = 50 -__C.network.name = 'Baseline' -__C.network.last_stride = 1 -__C.network.gpus = "1" -__C.network.workers = 8 - -__C.misc = edict() -__C.misc.log_interval = 10 -__C.misc.eval_step = 50 -__C.misc.save_step = 50 -__C.misc.save_dir = '' - - -def update_config(config_file): - exp_config = None - with open(config_file) as f: - exp_config = edict(yaml.load(f)) - for k, v in exp_config.items(): - if k in __C: - if isinstance(v, dict): - for vk, vv in v.items(): - __C[k][vk] = vv - else: - __C[k] = v - else: - raise ValueError("key must exist in configs.py") diff --git a/core/loader.py b/core/loader.py deleted file mode 100755 index ce56564..0000000 --- a/core/loader.py +++ /dev/null @@ -1,126 +0,0 @@ -from __future__ import print_function, absolute_import - -from collections import defaultdict - -import numpy as np -import torchvision.transforms as T -from PIL import Image -from torch.utils.data import Dataset, Sampler, DataLoader - -from utils import augmenter -from .data_manager import init_dataset - - -def read_image(img_path): - """Keep reading image until succeed. - This can avoid IOError incurred by heavy IO process.""" - got_img = False - while not got_img: - try: - img = Image.open(img_path).convert("RGB") - got_img = True - except IOError: - print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) - pass - return img - - -class ImageData(Dataset): - def __init__(self, dataset, transform): - self.dataset = dataset - self.transform = transform - - def __getitem__(self, item): - img, pid, camid = self.dataset[item] - img = read_image(img) - if self.transform is not None: - img = self.transform(img) - return img, pid, camid - - def __len__(self): - return len(self.dataset) - - -class RandomIdentitySampler(Sampler): - def __init__(self, data_source, num_instances=4): - self.data_source = data_source - self.num_instances = num_instances - self.index_dic = defaultdict(list) - for index, (_, pid, _) in enumerate(data_source): - self.index_dic[pid].append(index) - self.pids = list(self.index_dic.keys()) - self.num_identities = len(self.pids) - - def __iter__(self): - indices = np.random.permutation(self.num_identities) - ret = [] - for i in indices: - pid = self.pids[i] - t = self.index_dic[pid] - replace = False if len(t) >= self.num_instances else True - t = np.random.choice(t, size=self.num_instances, replace=replace) - ret.extend(t) - return iter(ret) - - def __len__(self): - return self.num_identities * self.num_instances - - -def get_data_provider(opt): - num_gpus = (len(opt.network.gpus) + 1) // 2 - test_batch_size = opt.test.batch_size * num_gpus - - # data augmenter - random_mirror = opt.aug.get('random_mirror', False) - pad = opt.aug.get('pad', False) - random_crop = opt.aug.get('random_crop', False) - random_erasing = opt.aug.get('random_erasing', False) - - h, w = opt.aug.resize_size - train_aug = list() - train_aug.append(T.Resize((h, w))) - if random_mirror: - train_aug.append(T.RandomHorizontalFlip()) - if pad: - train_aug.append(T.Pad(padding=pad)) - if random_crop: - train_aug.append(T.RandomCrop((h, w))) - train_aug.append(T.ToTensor()) - train_aug.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) - if random_erasing: - train_aug.append(augmenter.RandomErasing()) - train_aug = T.Compose(train_aug) - - test_aug = list() - test_aug.append(T.Resize((h, w))) - test_aug.append(T.ToTensor()) - test_aug.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) - test_aug = T.Compose(test_aug) - - dataset = init_dataset(opt.dataset.name) - train_set = ImageData(dataset.train, train_aug) - test_set = ImageData(dataset.query + dataset.gallery, test_aug) - - if opt.train.sampler == 'softmax': - train_batch_size = opt.train.batch_size * num_gpus - train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, - num_workers=opt.network.workers, pin_memory=True, drop_last=True) - elif opt.train.sampler == 'triplet': - train_batch_size = opt.train.p_size * num_gpus * opt.train.k_size - train_loader = DataLoader(train_set, batch_size=train_batch_size, - sampler=RandomIdentitySampler(dataset.train, opt.train.k_size), - num_workers=opt.network.workers, pin_memory=True) - else: - raise ValueError('sampler must be softmax or triplet, but get {}'.format(opt.train.sampler)) - - test_loader = DataLoader(test_set, batch_size=test_batch_size, num_workers=opt.network.workers, pin_memory=True) - return train_loader, test_loader, len(dataset.query) # return number of query - - -if __name__ == "__main__": - from config import opt - - train_loader, test_loader, num_query = get_data_provider(opt) - from IPython import embed - - embed() diff --git a/core/solver.py b/core/solver.py deleted file mode 100644 index a630792..0000000 --- a/core/solver.py +++ /dev/null @@ -1,187 +0,0 @@ -# encoding: utf-8 -""" -@author: sherlock -@contact: sherlockliao01@gmail.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import logging -import time - -import numpy as np -import torch - -from utils.meters import AverageMeter -from utils.serialization import save_checkpoint - - -class Solver(object): - def __init__(self, opt, net): - self.opt = opt - self.net = net - self.loss = AverageMeter('loss') - self.acc = AverageMeter('acc') - - def fit(self, train_data, test_data, num_query, optimizer, criterion, lr_scheduler): - best_rank1 = -np.inf - for epoch in range(self.opt.train.num_epochs): - self.loss.reset() - self.acc.reset() - self.net.train() - # update learning rate - lr = lr_scheduler.update(epoch) - for param_group in optimizer.param_groups: - param_group['lr'] = lr - logging.info('Epoch [{}] learning rate update to {:.3e}'.format(epoch, lr)) - - tic = time.time() - btic = time.time() - for i, inputs in enumerate(train_data): - data, pids, _ = inputs - label = pids.cuda() - score, feat = self.net(data) - loss = criterion(score, feat, label) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - self.loss.update(loss.item()) - acc = (score.max(1)[1] == label.long()).float().mean().item() - self.acc.update(acc) - - log_interval = self.opt.misc.log_interval - if log_interval and not (i + 1) % log_interval: - loss_name, loss_value = self.loss.get() - metric_name, metric_value = self.acc.get() - logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\t' - '%s=%f' % ( - epoch, i + 1, train_data.batch_size * log_interval / (time.time() - btic), - loss_name, loss_value, - metric_name, metric_value - )) - btic = time.time() - - loss_name, loss_value = self.loss.get() - metric_name, metric_value = self.acc.get() - throughput = int(train_data.batch_size * len(train_data) / (time.time() - tic)) - - logging.info('[Epoch %d] training: %s=%f\t%s=%f' % ( - epoch, loss_name, loss_value, metric_name, metric_value)) - logging.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' % (epoch, throughput, time.time() - tic)) - - is_best = False - if test_data is not None and self.opt.misc.eval_step and not (epoch + 1) % self.opt.misc.eval_step: - rank1 = self.test_func(test_data, num_query) - is_best = rank1 > best_rank1 - if is_best: - best_rank1 = rank1 - state_dict = self.net.module.state_dict() - if not (epoch + 1) % self.opt.misc.save_step: - save_checkpoint({ - 'state_dict': state_dict, - 'epoch': epoch + 1, - }, is_best=is_best, save_dir=self.opt.misc.save_dir, - filename=self.opt.network.name + '.pth.tar') - - def test_func(self, test_data, num_query): - self.net.eval() - feat, person, camera = list(), list(), list() - for inputs in test_data: - data, pids, camids = inputs - with torch.no_grad(): - outputs = self.net(data).cpu() - feat.append(outputs) - person.extend(pids.numpy()) - camera.extend(camids.numpy()) - feat = torch.cat(feat, 0) - qf = feat[:num_query] - q_pids = np.asarray(person[:num_query]) - q_camids = np.asarray(camera[:num_query]) - gf = feat[num_query:] - g_pids = np.asarray(person[num_query:]) - g_camids = np.asarray(camera[num_query:]) - - logging.info("Extracted features for query set, obtained {}-by-{} matrix".format( - qf.shape[0], qf.shape[1])) - logging.info("Extracted features for gallery set, obtained {}-by-{} matrix".format( - gf.shape[0], gf.shape[1])) - - logging.info("Computing distance matrix") - - m, n = qf.shape[0], gf.shape[0] - distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ - torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() - distmat.addmm_(1, -2, qf, gf.t()) - distmat = distmat.numpy() - - logging.info("Computing CMC and mAP") - cmc, mAP = self.eval_func(distmat, q_pids, g_pids, q_camids, g_camids) - - print("Results ----------") - print("mAP: {:.1%}".format(mAP)) - print("CMC curve") - for r in [1, 5, 10]: - print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1])) - print("------------------") - return cmc[0] - - @staticmethod - def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): - """Evaluation with market1501 metric - Key: for each query identity, its gallery images from the same camera view are discarded. - """ - num_q, num_g = distmat.shape - if num_g < max_rank: - max_rank = num_g - print("Note: number of gallery samples is quite small, got {}".format(num_g)) - indices = np.argsort(distmat, axis=1) - matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) - - # compute cmc curve for each query - all_cmc = [] - all_AP = [] - num_valid_q = 0. # number of valid query - for q_idx in range(num_q): - # get query pid and camid - q_pid = q_pids[q_idx] - q_camid = q_camids[q_idx] - - # remove gallery samples that have the same pid and camid with query - order = indices[q_idx] - remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) - keep = np.invert(remove) - - # compute cmc curve - # binary vector, positions with value 1 are correct matches - orig_cmc = matches[q_idx][keep] - if not np.any(orig_cmc): - # this condition is true when query identity does not appear in gallery - continue - - cmc = orig_cmc.cumsum() - cmc[cmc > 1] = 1 - - all_cmc.append(cmc[:max_rank]) - num_valid_q += 1. - - # compute average precision - # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision - num_rel = orig_cmc.sum() - tmp_cmc = orig_cmc.cumsum() - tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] - tmp_cmc = np.asarray(tmp_cmc) * orig_cmc - AP = tmp_cmc.sum() / num_rel - all_AP.append(AP) - - assert num_valid_q > 0, "Error: all query identities do not appear in gallery" - - all_cmc = np.asarray(all_cmc).astype(np.float32) - all_cmc = all_cmc.sum(0) / num_valid_q - mAP = np.mean(all_AP) - - return all_cmc, mAP diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..dafbc80 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,7 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +from .build import make_data_loader diff --git a/data/build.py b/data/build.py new file mode 100644 index 0000000..5fe9fc5 --- /dev/null +++ b/data/build.py @@ -0,0 +1,44 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +from torch.utils.data import DataLoader + +from .collate_batch import train_collate_fn, val_collate_fn +from .datasets import init_dataset, ImageDataset +from .samplers import RandomIdentitySampler +from .transforms import build_transforms + + +def make_data_loader(cfg): + train_transforms = build_transforms(cfg, is_train=True) + val_transforms = build_transforms(cfg, is_train=False) + num_workers = cfg.DATALOADER.NUM_WORKERS + if len(cfg.DATASETS.NAMES) == 1: + dataset = init_dataset(cfg.DATASETS.NAMES) + else: + # TODO: add multi dataset to train + dataset = init_dataset(cfg.DATASETS.NAMES) + + num_classes = dataset.num_train_pids + train_set = ImageDataset(dataset.train, train_transforms) + if cfg.DATALOADER.SAMPLER == 'softmax': + train_loader = DataLoader( + train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, + collate_fn=train_collate_fn + ) + else: + train_loader = DataLoader( + train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, + sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), + num_workers=num_workers, collate_fn=train_collate_fn + ) + + val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) + val_loader = DataLoader( + val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, + collate_fn=val_collate_fn + ) + return train_loader, val_loader, len(dataset.query), num_classes diff --git a/data/collate_batch.py b/data/collate_batch.py new file mode 100644 index 0000000..7d28901 --- /dev/null +++ b/data/collate_batch.py @@ -0,0 +1,18 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +import torch + + +def train_collate_fn(batch): + imgs, pids, _, _, = zip(*batch) + pids = torch.tensor(pids, dtype=torch.int64) + return torch.stack(imgs, dim=0), pids + + +def val_collate_fn(batch): + imgs, pids, camids, _ = zip(*batch) + return torch.stack(imgs, dim=0), pids, camids diff --git a/data/datasets/__init__.py b/data/datasets/__init__.py new file mode 100644 index 0000000..9baf635 --- /dev/null +++ b/data/datasets/__init__.py @@ -0,0 +1,25 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" +from .cuhk03 import CUHK03 +from .dukemtmcreid import DukeMTMCreID +from .market1501 import Market1501 +from .dataset_loader import ImageDataset + +__factory = { + 'market1501': Market1501, + 'cuhk03': CUHK03, + 'dukemtmc': DukeMTMCreID +} + + +def get_names(): + return __factory.keys() + + +def init_dataset(name, *args, **kwargs): + if name not in __factory.keys(): + raise KeyError("Unknown datasets: {}".format(name)) + return __factory[name](*args, **kwargs) diff --git a/data/datasets/bases.py b/data/datasets/bases.py new file mode 100644 index 0000000..88b3536 --- /dev/null +++ b/data/datasets/bases.py @@ -0,0 +1,95 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +import numpy as np + + +class BaseDataset(object): + """ + Base class of reid dataset + """ + + def get_imagedata_info(self, data): + pids, cams = [], [] + for _, pid, camid in data: + pids += [pid] + cams += [camid] + pids = set(pids) + cams = set(cams) + num_pids = len(pids) + num_cams = len(cams) + num_imgs = len(data) + return num_pids, num_imgs, num_cams + + def get_videodata_info(self, data, return_tracklet_stats=False): + pids, cams, tracklet_stats = [], [], [] + for img_paths, pid, camid in data: + pids += [pid] + cams += [camid] + tracklet_stats += [len(img_paths)] + pids = set(pids) + cams = set(cams) + num_pids = len(pids) + num_cams = len(cams) + num_tracklets = len(data) + if return_tracklet_stats: + return num_pids, num_tracklets, num_cams, tracklet_stats + return num_pids, num_tracklets, num_cams + + def print_dataset_statistics(self): + raise NotImplementedError + + +class BaseImageDataset(BaseDataset): + """ + Base class of image reid dataset + """ + + def print_dataset_statistics(self, train, query, gallery): + num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) + num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) + num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) + + print("Dataset statistics:") + print(" ----------------------------------------") + print(" subset | # ids | # images | # cameras") + print(" ----------------------------------------") + print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) + print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) + print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) + print(" ----------------------------------------") + + +class BaseVideoDataset(BaseDataset): + """ + Base class of video reid dataset + """ + + def print_dataset_statistics(self, train, query, gallery): + num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ + self.get_videodata_info(train, return_tracklet_stats=True) + + num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ + self.get_videodata_info(query, return_tracklet_stats=True) + + num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ + self.get_videodata_info(gallery, return_tracklet_stats=True) + + tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats + min_num = np.min(tracklet_stats) + max_num = np.max(tracklet_stats) + avg_num = np.mean(tracklet_stats) + + print("Dataset statistics:") + print(" -------------------------------------------") + print(" subset | # ids | # tracklets | # cameras") + print(" -------------------------------------------") + print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) + print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) + print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) + print(" -------------------------------------------") + print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) + print(" -------------------------------------------") diff --git a/data/datasets/cuhk03.py b/data/datasets/cuhk03.py new file mode 100644 index 0000000..5af18f2 --- /dev/null +++ b/data/datasets/cuhk03.py @@ -0,0 +1,259 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: liaoxingyu2@jd.com +""" + +import h5py +import os.path as osp +from scipy.io import loadmat +from scipy.misc import imsave + +from utils.iotools import mkdir_if_missing, write_json, read_json +from .bases import BaseImageDataset + + +class CUHK03(BaseImageDataset): + """ + CUHK03 + Reference: + Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. + URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#! + + Dataset statistics: + # identities: 1360 + # images: 13164 + # cameras: 6 + # splits: 20 (classic) + Args: + split_id (int): split index (default: 0) + cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False) + """ + dataset_dir = 'cuhk03' + + def __init__(self, root='/export/home/lxy/DATA/reid', split_id=0, cuhk03_labeled=False, + cuhk03_classic_split=False, verbose=True, + **kwargs): + super(CUHK03, self).__init__() + self.dataset_dir = osp.join(root, self.dataset_dir) + self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release') + self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat') + + self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected') + self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled') + + self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json') + self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json') + + self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json') + self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json') + + self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat') + self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat') + + self._check_before_run() + self._preprocess() + + if cuhk03_labeled: + image_type = 'labeled' + split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path + else: + image_type = 'detected' + split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path + + splits = read_json(split_path) + assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, + len(splits)) + split = splits[split_id] + print("Split index = {}".format(split_id)) + + train = split['train'] + query = split['query'] + gallery = split['gallery'] + + if verbose: + print("=> CUHK03 ({}) loaded".format(image_type)) + self.print_dataset_statistics(train, query, gallery) + + self.train = train + self.query = query + self.gallery = gallery + + self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) + self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) + self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) + + def _check_before_run(self): + """Check if all files are available before going deeper""" + if not osp.exists(self.dataset_dir): + raise RuntimeError("'{}' is not available".format(self.dataset_dir)) + if not osp.exists(self.data_dir): + raise RuntimeError("'{}' is not available".format(self.data_dir)) + if not osp.exists(self.raw_mat_path): + raise RuntimeError("'{}' is not available".format(self.raw_mat_path)) + if not osp.exists(self.split_new_det_mat_path): + raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path)) + if not osp.exists(self.split_new_lab_mat_path): + raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path)) + + def _preprocess(self): + """ + This function is a bit complex and ugly, what it does is + 1. Extract data from cuhk-03.mat and save as png images. + 2. Create 20 classic splits. (Li et al. CVPR'14) + 3. Create new split. (Zhong et al. CVPR'17) + """ + print( + "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)") + if osp.exists(self.imgs_labeled_dir) and \ + osp.exists(self.imgs_detected_dir) and \ + osp.exists(self.split_classic_det_json_path) and \ + osp.exists(self.split_classic_lab_json_path) and \ + osp.exists(self.split_new_det_json_path) and \ + osp.exists(self.split_new_lab_json_path): + return + + mkdir_if_missing(self.imgs_detected_dir) + mkdir_if_missing(self.imgs_labeled_dir) + + print("Extract image data from {} and save as png".format(self.raw_mat_path)) + mat = h5py.File(self.raw_mat_path, 'r') + + def _deref(ref): + return mat[ref][:].T + + def _process_images(img_refs, campid, pid, save_dir): + img_paths = [] # Note: some persons only have images for one view + for imgid, img_ref in enumerate(img_refs): + img = _deref(img_ref) + # skip empty cell + if img.size == 0 or img.ndim < 3: continue + # images are saved with the following format, index-1 (ensure uniqueness) + # campid: index of camera pair (1-5) + # pid: index of person in 'campid'-th camera pair + # viewid: index of view, {1, 2} + # imgid: index of image, (1-10) + viewid = 1 if imgid < 5 else 2 + img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1) + img_path = osp.join(save_dir, img_name) + if not osp.isfile(img_path): + imsave(img_path, img) + img_paths.append(img_path) + return img_paths + + def _extract_img(name): + print("Processing {} images (extract and save) ...".format(name)) + meta_data = [] + imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir + for campid, camp_ref in enumerate(mat[name][0]): + camp = _deref(camp_ref) + num_pids = camp.shape[0] + for pid in range(num_pids): + img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir) + assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid) + meta_data.append((campid + 1, pid + 1, img_paths)) + print("- done camera pair {} with {} identities".format(campid + 1, num_pids)) + return meta_data + + meta_detected = _extract_img('detected') + meta_labeled = _extract_img('labeled') + + def _extract_classic_split(meta_data, test_split): + train, test = [], [] + num_train_pids, num_test_pids = 0, 0 + num_train_imgs, num_test_imgs = 0, 0 + for i, (campid, pid, img_paths) in enumerate(meta_data): + + if [campid, pid] in test_split: + for img_path in img_paths: + camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based + test.append((img_path, num_test_pids, camid)) + num_test_pids += 1 + num_test_imgs += len(img_paths) + else: + for img_path in img_paths: + camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based + train.append((img_path, num_train_pids, camid)) + num_train_pids += 1 + num_train_imgs += len(img_paths) + return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs + + print("Creating classic splits (# = 20) ...") + splits_classic_det, splits_classic_lab = [], [] + for split_ref in mat['testsets'][0]: + test_split = _deref(split_ref).tolist() + + # create split for detected images + train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ + _extract_classic_split(meta_detected, test_split) + splits_classic_det.append({ + 'train': train, 'query': test, 'gallery': test, + 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, + 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, + 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, + }) + + # create split for labeled images + train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ + _extract_classic_split(meta_labeled, test_split) + splits_classic_lab.append({ + 'train': train, 'query': test, 'gallery': test, + 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, + 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, + 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, + }) + + write_json(splits_classic_det, self.split_classic_det_json_path) + write_json(splits_classic_lab, self.split_classic_lab_json_path) + + def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel): + tmp_set = [] + unique_pids = set() + for idx in idxs: + img_name = filelist[idx][0] + camid = int(img_name.split('_')[2]) - 1 # make it 0-based + pid = pids[idx] + if relabel: pid = pid2label[pid] + img_path = osp.join(img_dir, img_name) + tmp_set.append((img_path, int(pid), camid)) + unique_pids.add(pid) + return tmp_set, len(unique_pids), len(idxs) + + def _extract_new_split(split_dict, img_dir): + train_idxs = split_dict['train_idx'].flatten() - 1 # index-0 + pids = split_dict['labels'].flatten() + train_pids = set(pids[train_idxs]) + pid2label = {pid: label for label, pid in enumerate(train_pids)} + query_idxs = split_dict['query_idx'].flatten() - 1 + gallery_idxs = split_dict['gallery_idx'].flatten() - 1 + filelist = split_dict['filelist'].flatten() + train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True) + query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False) + gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False) + return train_info, query_info, gallery_info + + print("Creating new splits for detected images (767/700) ...") + train_info, query_info, gallery_info = _extract_new_split( + loadmat(self.split_new_det_mat_path), + self.imgs_detected_dir, + ) + splits = [{ + 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], + 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], + 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], + 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], + }] + write_json(splits, self.split_new_det_json_path) + + print("Creating new splits for labeled images (767/700) ...") + train_info, query_info, gallery_info = _extract_new_split( + loadmat(self.split_new_lab_mat_path), + self.imgs_labeled_dir, + ) + splits = [{ + 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], + 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], + 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], + 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], + }] + write_json(splits, self.split_new_lab_json_path) diff --git a/data/datasets/dataset_loader.py b/data/datasets/dataset_loader.py new file mode 100644 index 0000000..b75ead8 --- /dev/null +++ b/data/datasets/dataset_loader.py @@ -0,0 +1,45 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +import os.path as osp +from PIL import Image +from torch.utils.data import Dataset + + +def read_image(img_path): + """Keep reading image until succeed. + This can avoid IOError incurred by heavy IO process.""" + got_img = False + if not osp.exists(img_path): + raise IOError("{} does not exist".format(img_path)) + while not got_img: + try: + img = Image.open(img_path).convert('RGB') + got_img = True + except IOError: + print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) + pass + return img + + +class ImageDataset(Dataset): + """Image Person ReID Dataset""" + + def __init__(self, dataset, transform=None): + self.dataset = dataset + self.transform = transform + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + img_path, pid, camid = self.dataset[index] + img = read_image(img_path) + + if self.transform is not None: + img = self.transform(img) + + return img, pid, camid, img_path diff --git a/data/datasets/dukemtmcreid.py b/data/datasets/dukemtmcreid.py new file mode 100644 index 0000000..ccceceb --- /dev/null +++ b/data/datasets/dukemtmcreid.py @@ -0,0 +1,106 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: liaoxingyu2@jd.com +""" + +import glob +import re +import urllib +import zipfile + +import os.path as osp + +from utils.iotools import mkdir_if_missing +from .bases import BaseImageDataset + + +class DukeMTMCreID(BaseImageDataset): + """ + DukeMTMC-reID + Reference: + 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. + 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. + URL: https://github.com/layumi/DukeMTMC-reID_evaluation + + Dataset statistics: + # identities: 1404 (train + query) + # images:16522 (train) + 2228 (query) + 17661 (gallery) + # cameras: 8 + """ + dataset_dir = 'dukemtmc-reid' + + def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs): + super(DukeMTMCreID, self).__init__() + self.dataset_dir = osp.join(root, self.dataset_dir) + self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' + self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') + self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') + self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') + + self._download_data() + self._check_before_run() + + train = self._process_dir(self.train_dir, relabel=True) + query = self._process_dir(self.query_dir, relabel=False) + gallery = self._process_dir(self.gallery_dir, relabel=False) + + if verbose: + print("=> DukeMTMC-reID loaded") + self.print_dataset_statistics(train, query, gallery) + + self.train = train + self.query = query + self.gallery = gallery + + self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) + self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) + self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) + + def _download_data(self): + if osp.exists(self.dataset_dir): + print("This dataset has been downloaded.") + return + + print("Creating directory {}".format(self.dataset_dir)) + mkdir_if_missing(self.dataset_dir) + fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) + + print("Downloading DukeMTMC-reID dataset") + urllib.urlretrieve(self.dataset_url, fpath) + + print("Extracting files") + zip_ref = zipfile.ZipFile(fpath, 'r') + zip_ref.extractall(self.dataset_dir) + zip_ref.close() + + def _check_before_run(self): + """Check if all files are available before going deeper""" + if not osp.exists(self.dataset_dir): + raise RuntimeError("'{}' is not available".format(self.dataset_dir)) + if not osp.exists(self.train_dir): + raise RuntimeError("'{}' is not available".format(self.train_dir)) + if not osp.exists(self.query_dir): + raise RuntimeError("'{}' is not available".format(self.query_dir)) + if not osp.exists(self.gallery_dir): + raise RuntimeError("'{}' is not available".format(self.gallery_dir)) + + def _process_dir(self, dir_path, relabel=False): + img_paths = glob.glob(osp.join(dir_path, '*.jpg')) + pattern = re.compile(r'([-\d]+)_c(\d)') + + pid_container = set() + for img_path in img_paths: + pid, _ = map(int, pattern.search(img_path).groups()) + pid_container.add(pid) + pid2label = {pid: label for label, pid in enumerate(pid_container)} + + dataset = [] + for img_path in img_paths: + pid, camid = map(int, pattern.search(img_path).groups()) + assert 1 <= camid <= 8 + camid -= 1 # index starts from 0 + if relabel: pid = pid2label[pid] + dataset.append((img_path, pid, camid)) + + return dataset diff --git a/data/datasets/eval_reid.py b/data/datasets/eval_reid.py new file mode 100644 index 0000000..38f9783 --- /dev/null +++ b/data/datasets/eval_reid.py @@ -0,0 +1,63 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +import numpy as np + + +def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): + """Evaluation with market1501 metric + Key: for each query identity, its gallery images from the same camera view are discarded. + """ + num_q, num_g = distmat.shape + if num_g < max_rank: + max_rank = num_g + print("Note: number of gallery samples is quite small, got {}".format(num_g)) + indices = np.argsort(distmat, axis=1) + matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) + + # compute cmc curve for each query + all_cmc = [] + all_AP = [] + num_valid_q = 0. # number of valid query + for q_idx in range(num_q): + # get query pid and camid + q_pid = q_pids[q_idx] + q_camid = q_camids[q_idx] + + # remove gallery samples that have the same pid and camid with query + order = indices[q_idx] + remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) + keep = np.invert(remove) + + # compute cmc curve + # binary vector, positions with value 1 are correct matches + orig_cmc = matches[q_idx][keep] + if not np.any(orig_cmc): + # this condition is true when query identity does not appear in gallery + continue + + cmc = orig_cmc.cumsum() + cmc[cmc > 1] = 1 + + all_cmc.append(cmc[:max_rank]) + num_valid_q += 1. + + # compute average precision + # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision + num_rel = orig_cmc.sum() + tmp_cmc = orig_cmc.cumsum() + tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] + tmp_cmc = np.asarray(tmp_cmc) * orig_cmc + AP = tmp_cmc.sum() / num_rel + all_AP.append(AP) + + assert num_valid_q > 0, "Error: all query identities do not appear in gallery" + + all_cmc = np.asarray(all_cmc).astype(np.float32) + all_cmc = all_cmc.sum(0) / num_valid_q + mAP = np.mean(all_AP) + + return all_cmc, mAP diff --git a/core/data_manager.py b/data/datasets/market1501.py old mode 100755 new mode 100644 similarity index 52% rename from core/data_manager.py rename to data/datasets/market1501.py index 470a23f..8e39b0c --- a/core/data_manager.py +++ b/data/datasets/market1501.py @@ -1,13 +1,18 @@ -from __future__ import print_function, absolute_import +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" import glob import re -from os import path as osp -"""Dataset classes""" +import os.path as osp + +from .bases import BaseImageDataset -class Market1501(object): +class Market1501(BaseImageDataset): """ Market1501 Reference: @@ -18,9 +23,10 @@ class Market1501(object): # identities: 1501 (+1 for background) # images: 12936 (train) + 3368 (query) + 15913 (gallery) """ - dataset_dir = 'Market-1501-v15.09.15' + dataset_dir = 'market1501' - def __init__(self, root='/home/test2/DATA/market1501/raw/'): + def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs): + super(Market1501, self).__init__() self.dataset_dir = osp.join(root, self.dataset_dir) self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') self.query_dir = osp.join(self.dataset_dir, 'query') @@ -28,31 +34,21 @@ class Market1501(object): self._check_before_run() - train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True) - query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False) - gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False) - num_total_pids = num_train_pids + num_query_pids - num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs + train = self._process_dir(self.train_dir, relabel=True) + query = self._process_dir(self.query_dir, relabel=False) + gallery = self._process_dir(self.gallery_dir, relabel=False) - print("=> Market1501 loaded") - print("Dataset statistics:") - print(" ------------------------------") - print(" subset | # ids | # images") - print(" ------------------------------") - print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) - print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) - print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) - print(" ------------------------------") - print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) - print(" ------------------------------") + if verbose: + print("=> Market1501 loaded") + self.print_dataset_statistics(train, query, gallery) self.train = train self.query = query self.gallery = gallery - self.num_train_pids = num_train_pids - self.num_query_pids = num_query_pids - self.num_gallery_pids = num_gallery_pids + self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) + self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) + self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) def _check_before_run(self): """Check if all files are available before going deeper""" @@ -79,31 +75,11 @@ class Market1501(object): dataset = [] for img_path in img_paths: pid, camid = map(int, pattern.search(img_path).groups()) - if pid == -1: - continue # junk images are just ignored + if pid == -1: continue # junk images are just ignored assert 0 <= pid <= 1501 # pid == 0 means background assert 1 <= camid <= 6 camid -= 1 # index starts from 0 if relabel: pid = pid2label[pid] dataset.append((img_path, pid, camid)) - num_pids = len(pid_container) - num_imgs = len(dataset) - return dataset, num_pids, num_imgs - - -"""Create datasets""" - -__factory = { - 'market1501': Market1501 -} - - -def get_names(): - return __factory.keys() - - -def init_dataset(name, *args, **kwargs): - if name not in __factory.keys(): - raise KeyError("Unknown datasets: {}".format(name)) - return __factory[name](*args, **kwargs) + return dataset diff --git a/data/samplers/__init__.py b/data/samplers/__init__.py new file mode 100644 index 0000000..e0c5684 --- /dev/null +++ b/data/samplers/__init__.py @@ -0,0 +1,7 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +from .triplet_sampler import RandomIdentitySampler diff --git a/data/samplers/triplet_sampler.py b/data/samplers/triplet_sampler.py new file mode 100644 index 0000000..65ead5d --- /dev/null +++ b/data/samplers/triplet_sampler.py @@ -0,0 +1,73 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: liaoxingyu2@jd.com +""" + +import copy +import random +from collections import defaultdict + +import numpy as np +from torch.utils.data.sampler import Sampler + + +class RandomIdentitySampler(Sampler): + """ + Randomly sample N identities, then for each identity, + randomly sample K instances, therefore batch size is N*K. + Args: + - data_source (list): list of (img_path, pid, camid). + - num_instances (int): number of instances per identity in a batch. + - batch_size (int): number of examples in a batch. + """ + + def __init__(self, data_source, batch_size, num_instances): + self.data_source = data_source + self.batch_size = batch_size + self.num_instances = num_instances + self.num_pids_per_batch = self.batch_size // self.num_instances + self.index_dic = defaultdict(list) + for index, (_, pid, _) in enumerate(self.data_source): + self.index_dic[pid].append(index) + self.pids = list(self.index_dic.keys()) + + # estimate number of examples in an epoch + self.length = 0 + for pid in self.pids: + idxs = self.index_dic[pid] + num = len(idxs) + if num < self.num_instances: + num = self.num_instances + self.length += num - num % self.num_instances + + def __iter__(self): + batch_idxs_dict = defaultdict(list) + + for pid in self.pids: + idxs = copy.deepcopy(self.index_dic[pid]) + if len(idxs) < self.num_instances: + idxs = np.random.choice(idxs, size=self.num_instances, replace=True) + random.shuffle(idxs) + batch_idxs = [] + for idx in idxs: + batch_idxs.append(idx) + if len(batch_idxs) == self.num_instances: + batch_idxs_dict[pid].append(batch_idxs) + batch_idxs = [] + + avai_pids = copy.deepcopy(self.pids) + final_idxs = [] + + while len(avai_pids) >= self.num_pids_per_batch: + selected_pids = random.sample(avai_pids, self.num_pids_per_batch) + for pid in selected_pids: + batch_idxs = batch_idxs_dict[pid].pop(0) + final_idxs.extend(batch_idxs) + if len(batch_idxs_dict[pid]) == 0: + avai_pids.remove(pid) + + return iter(final_idxs) + + def __len__(self): + return self.length diff --git a/data/transforms/__init__.py b/data/transforms/__init__.py new file mode 100644 index 0000000..51357aa --- /dev/null +++ b/data/transforms/__init__.py @@ -0,0 +1,7 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +from .build import build_transforms diff --git a/data/transforms/build.py b/data/transforms/build.py new file mode 100644 index 0000000..8562353 --- /dev/null +++ b/data/transforms/build.py @@ -0,0 +1,31 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: liaoxingyu2@jd.com +""" + +import torchvision.transforms as T + +from .transforms import RandomErasing + + +def build_transforms(cfg, is_train=True): + normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) + if is_train: + transform = T.Compose([ + T.Resize(cfg.INPUT.SIZE_TRAIN), + T.RandomHorizontalFlip(p=cfg.INPUT.PROB), + T.Pad(cfg.INPUT.PADDING), + T.RandomCrop(cfg.INPUT.SIZE_TRAIN), + T.ToTensor(), + normalize_transform, + RandomErasing(probability=cfg.INPUT.PROB, mean=cfg.INPUT.PIXEL_MEAN) + ]) + else: + transform = T.Compose([ + T.Resize(cfg.INPUT.SIZE_TEST), + T.ToTensor(), + normalize_transform + ]) + + return transform diff --git a/utils/augmenter.py b/data/transforms/transforms.py similarity index 53% rename from utils/augmenter.py rename to data/transforms/transforms.py index 6f8dda1..b1ac617 100644 --- a/utils/augmenter.py +++ b/data/transforms/transforms.py @@ -1,57 +1,12 @@ # encoding: utf-8 """ @author: liaoxingyu -@contact: sherlockliao01@gmail.com +@contact: liaoxingyu2@jd.com """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - import math import random -from PIL import Image - - -class Random2DTranslation(object): - """ - With a probability, first increase image size to (1 + 1/8), and then perform random crop. - - Args: - height (int): target height. - width (int): target width. - p (float): probability of performing this transformation. Default: 0.5. - """ - - def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): - self.height = height - self.width = width - self.p = p - self.interpolation = interpolation - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be cropped. - - Returns: - PIL Image: Cropped image. - """ - if random.random() < self.p: - return img.resize((self.width, self.height), self.interpolation) - new_width, new_height = int( - round(self.width * 1.125)), int(round(self.height * 1.125)) - resized_img = img.resize((new_width, new_height), self.interpolation) - x_maxrange = new_width - self.width - y_maxrange = new_height - self.height - x1 = int(round(random.uniform(0, x_maxrange))) - y1 = int(round(random.uniform(0, y_maxrange))) - croped_img = resized_img.crop( - (x1, y1, x1 + self.width, y1 + self.height)) - return croped_img - class RandomErasing(object): """ Randomly selects a rectangle region in an image and erases its pixels. @@ -65,7 +20,7 @@ class RandomErasing(object): mean: Erasing value. """ - def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): + def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): self.probability = probability self.mean = mean self.sl = sl diff --git a/engine/inference.py b/engine/inference.py new file mode 100644 index 0000000..3f912ef --- /dev/null +++ b/engine/inference.py @@ -0,0 +1,64 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" +import logging + +import torch +from ignite.engine import Engine + +from utils.reid_metric import R1_mAP + + +def create_supervised_evaluator(model, metrics, + device=None): + """ + Factory function for creating an evaluator for supervised models + + Args: + model (`torch.nn.Module`): the model to train + metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics + device (str, optional): device type specification (default: None). + Applies to both model and batches. + Returns: + Engine: an evaluator engine with supervised inference function + """ + if device: + model.to(device) + + def _inference(engine, batch): + model.eval() + with torch.no_grad(): + data, pids, camids = batch + data = data.cuda() + feat = model(data) + return feat, pids, camids + + engine = Engine(_inference) + + for name, metric in metrics.items(): + metric.attach(engine, name) + + return engine + + +def inference( + cfg, + model, + val_loader, + num_query +): + device = cfg.MODEL.DEVICE + + logger = logging.getLogger("reid_baseline.inference") + logger.info("Start inferencing") + evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, + device=device) + + evaluator.run(val_loader) + cmc, mAP = evaluator.state.metrics['r1_mAP'] + logger.info('Validation Results') + logger.info("mAP: {:.1%}".format(mAP)) + for r in [1, 5, 10]: + logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) diff --git a/engine/trainer.py b/engine/trainer.py new file mode 100644 index 0000000..b1e9ab7 --- /dev/null +++ b/engine/trainer.py @@ -0,0 +1,150 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +import logging + +import torch +from ignite.engine import Engine, Events +from ignite.handlers import ModelCheckpoint, Timer +from ignite.metrics import RunningAverage + +from utils.reid_metric import R1_mAP + + +def create_supervised_trainer(model, optimizer, loss_fn, + device=None): + """ + Factory function for creating a trainer for supervised models + + Args: + model (`torch.nn.Module`): the model to train + optimizer (`torch.optim.Optimizer`): the optimizer to use + loss_fn (torch.nn loss function): the loss function to use + device (str, optional): device type specification (default: None). + Applies to both model and batches. + + Returns: + Engine: a trainer engine with supervised update function + """ + if device: + model.to(device) + + def _update(engine, batch): + model.train() + optimizer.zero_grad() + img, target = batch + img = img.cuda() + target = target.cuda() + score, feat = model(img) + loss = loss_fn(score, feat, target) + loss.backward() + optimizer.step() + # compute acc + acc = (score.max(1)[1] == target).float().mean() + return loss.item(), acc.item() + + return Engine(_update) + + +def create_supervised_evaluator(model, metrics, + device=None): + """ + Factory function for creating an evaluator for supervised models + + Args: + model (`torch.nn.Module`): the model to train + metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics + device (str, optional): device type specification (default: None). + Applies to both model and batches. + Returns: + Engine: an evaluator engine with supervised inference function + """ + if device: + model.to(device) + + def _inference(engine, batch): + model.eval() + with torch.no_grad(): + data, pids, camids = batch + data = data.cuda() + feat = model(data) + return feat, pids, camids + + engine = Engine(_inference) + + for name, metric in metrics.items(): + metric.attach(engine, name) + + return engine + + +def do_train( + cfg, + model, + train_loader, + val_loader, + optimizer, + scheduler, + loss_fn, + num_query +): + log_period = cfg.SOLVER.LOG_PERIOD + checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD + eval_period = cfg.SOLVER.EVAL_PERIOD + output_dir = cfg.OUTPUT_DIR + device = cfg.MODEL.DEVICE + epochs = cfg.SOLVER.MAX_EPOCHS + + logger = logging.getLogger("reid_baseline.train") + logger.info("Start training") + trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) + evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device) + checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) + timer = Timer(average=True) + + trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(), + 'optimizer': optimizer.state_dict()}) + timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, + pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) + + # average metric to attach on trainer + RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss') + RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc') + + @trainer.on(Events.EPOCH_STARTED) + def adjust_learning_rate(engine): + scheduler.step() + + @trainer.on(Events.ITERATION_COMPLETED) + def log_training_loss(engine): + iter = (engine.state.iteration - 1) % len(train_loader) + 1 + + if iter % log_period == 0: + logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" + .format(engine.state.epoch, iter, len(train_loader), + engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'], + scheduler.get_lr()[0])) + + # adding handlers using `trainer.on` decorator API + @trainer.on(Events.EPOCH_COMPLETED) + def print_times(engine): + logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' + .format(engine.state.epoch, timer.value() * timer.step_count, + train_loader.batch_size / timer.value())) + logger.info('-' * 10) + timer.reset() + + @trainer.on(Events.EPOCH_COMPLETED) + def log_validation_results(engine): + if engine.state.epoch % eval_period == 0: + evaluator.run(val_loader) + cmc, mAP = evaluator.state.metrics['r1_mAP'] + logger.info("Validation Results - Epoch: {}".format(engine.state.epoch)) + logger.info("mAP: {:.1%}".format(mAP)) + for r in [1, 5, 10]: + logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) + + trainer.run(train_loader, max_epochs=epochs) diff --git a/layers/__init__.py b/layers/__init__.py new file mode 100644 index 0000000..d28327b --- /dev/null +++ b/layers/__init__.py @@ -0,0 +1,28 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +import torch.nn.functional as F + +from .triplet_loss import TripletLoss + + +def make_loss(cfg): + sampler = cfg.DATALOADER.SAMPLER + triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss + + if sampler == 'softmax': + def loss_func(score, feat, target): + return F.cross_entropy(score, target) + elif cfg.DATALOADER.SAMPLER == 'triplet': + def loss_func(score, feat, target): + return triplet(feat, target)[0] + elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': + def loss_func(score, feat, target): + return F.cross_entropy(score, target) + triplet(feat, target)[0] + else: + print('expected sampler should be softmax, triplet or softmax_triplet, ' + 'but got {}'.format(cfg.DATALOADER.SAMPLER)) + return loss_func diff --git a/utils/loss.py b/layers/triplet_loss.py similarity index 74% rename from utils/loss.py rename to layers/triplet_loss.py index 33d2c5e..3987530 100644 --- a/utils/loss.py +++ b/layers/triplet_loss.py @@ -1,17 +1,10 @@ # encoding: utf-8 """ @author: liaoxingyu -@contact: xyliao1993@qq.com +@contact: sherlockliao01@gmail.com """ - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - import torch from torch import nn -import torch.nn.functional as F def normalize(x, axis=-1): @@ -121,34 +114,3 @@ class TripletLoss(object): else: loss = self.ranking_loss(dist_an - dist_ap, y) return loss, dist_ap, dist_an - - -class CrossEntropyLabelSmooth(nn.Module): - """Cross entropy loss with label smoothing regularizer. - Reference: - Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. - Equation: y = (1 - epsilon) * y + epsilon / K. - Args: - num_classes (int): number of classes. - epsilon (float): weight. - """ - - def __init__(self, num_classes, epsilon=0.1, use_gpu=True): - super(CrossEntropyLabelSmooth, self).__init__() - self.num_classes = num_classes - self.epsilon = epsilon - self.use_gpu = use_gpu - self.logsoftmax = nn.LogSoftmax(dim=1) - - def forward(self, inputs, targets): - """ - Args: - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) - targets: ground truth labels with shape (num_classes) - """ - log_probs = self.logsoftmax(inputs) - targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) - if self.use_gpu: targets = targets.cuda() - targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes - loss = (- targets * log_probs).mean(0).sum() - return loss diff --git a/modeling/__init__.py b/modeling/__init__.py new file mode 100644 index 0000000..39ca962 --- /dev/null +++ b/modeling/__init__.py @@ -0,0 +1,13 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +from .baseline import Baseline + + +def build_model(cfg, num_classes): + if cfg.MODEL.NAME == 'resnet50': + model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH) + return model diff --git a/modeling/backbones/__init__.py b/modeling/backbones/__init__.py new file mode 100644 index 0000000..eb25c85 --- /dev/null +++ b/modeling/backbones/__init__.py @@ -0,0 +1,6 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + diff --git a/network/resnet.py b/modeling/backbones/resnet.py similarity index 89% rename from network/resnet.py rename to modeling/backbones/resnet.py index fa6ca5d..0f44b38 100644 --- a/network/resnet.py +++ b/modeling/backbones/resnet.py @@ -1,17 +1,12 @@ # encoding: utf-8 """ -@author: liaoxingyu -@contact: liaoxingyu@megvii.com +@author: liaoxingyu +@contact: sherlockliao01@gmail.com """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - import math -import torch as th +import torch from torch import nn @@ -98,7 +93,7 @@ class ResNet(nn.Module): return x def load_param(self, model_path): - param_dict = th.load(model_path) + param_dict = torch.load(model_path) for i in param_dict: if 'fc' in i: continue @@ -112,11 +107,3 @@ class ResNet(nn.Module): elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() - - -if __name__ == "__main__": - net = ResNet(last_stride=2) - import torch - - x = net(torch.zeros(1, 3, 256, 128)) - print(x.shape) diff --git a/network/baseline.py b/modeling/baseline.py similarity index 77% rename from network/baseline.py rename to modeling/baseline.py index 05e70c7..b7fdda7 100644 --- a/network/baseline.py +++ b/modeling/baseline.py @@ -1,17 +1,12 @@ # encoding: utf-8 """ @author: liaoxingyu -@contact: xyliao1993@qq.com +@contact: sherlockliao01@gmail.com """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - from torch import nn -from .resnet import ResNet +from .backbones.resnet import ResNet def weights_init_kaiming(m): @@ -40,11 +35,12 @@ def weights_init_classifier(m): class Baseline(nn.Module): in_planes = 2048 - def __init__(self, num_classes=10, last_stride=1, model_path='/home/test2/.torch/models/resnet50-19c8e357.pth'): + def __init__(self, num_classes, last_stride, model_path): super(Baseline, self).__init__() self.base = ResNet(last_stride) self.base.load_param(model_path) self.gap = nn.AdaptiveAvgPool2d(1) + # self.gap = nn.AdaptiveMaxPool2d(1) self.num_classes = num_classes self.bottleneck = nn.BatchNorm1d(self.in_planes) @@ -63,15 +59,3 @@ class Baseline(nn.Module): return cls_score, global_feat # global feature for triplet loss else: return feat - - -if __name__ == '__main__': - # net = Baseline(751).cuda(1) - import torch - - net = ResNet(1).cuda(1) - x = torch.ones(128, 3, 256, 128).cuda(1) - y = net(x) - from IPython import embed - - embed() diff --git a/network/__init__.py b/network/__init__.py deleted file mode 100644 index 30a1387..0000000 --- a/network/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# encoding: utf-8 -""" -@author: liaoxingyu -@contact: xyliao1993@qq.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - - -from .baseline import Baseline diff --git a/scripts/test.sh b/scripts/test.sh deleted file mode 100644 index c96c2d4..0000000 --- a/scripts/test.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash - - -python3 tools/test.py --config_file='configs/market_softmax_triplet.yml' \ ---load_model='/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax_triplet/350_Baseline350.pth.tar' \ No newline at end of file diff --git a/scripts/train_softmax.sh b/scripts/train_softmax.sh deleted file mode 100644 index c07b0af..0000000 --- a/scripts/train_softmax.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax/ -mkdir -p ${checkpoint_dir} - -python3 tools/train.py --config_file='configs/market_softmax.yml' \ ---save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log - diff --git a/scripts/train_softmax_triplet.sh b/scripts/train_softmax_triplet.sh deleted file mode 100644 index 4899e07..0000000 --- a/scripts/train_softmax_triplet.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_softmax_triplet/ -mkdir -p ${checkpoint_dir} - -python3 tools/train.py --config_file='configs/market_softmax_triplet.yml' \ ---save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log - diff --git a/scripts/train_triplet.sh b/scripts/train_triplet.sh deleted file mode 100644 index 2859a3f..0000000 --- a/scripts/train_triplet.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env bash - -checkpoint_dir=/home/test2/liaoxingyu/pytorch-ckpt/reid/market_triplet/ -mkdir -p ${checkpoint_dir} - -python3 tools/train.py --config_file='configs/market_triplet.yml' \ ---save_dir=${checkpoint_dir} | tee ${checkpoint_dir}/train.log - diff --git a/solver/__init__.py b/solver/__init__.py new file mode 100644 index 0000000..fb7c7dd --- /dev/null +++ b/solver/__init__.py @@ -0,0 +1,8 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +from .build import make_optimizer +from .lr_scheduler import WarmupMultiStepLR \ No newline at end of file diff --git a/solver/build.py b/solver/build.py new file mode 100644 index 0000000..98df468 --- /dev/null +++ b/solver/build.py @@ -0,0 +1,25 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +import torch + + +def make_optimizer(cfg, model): + params = [] + for key, value in model.named_parameters(): + if not value.requires_grad: + continue + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + if "bias" in key: + lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR + weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS + params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] + if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': + optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) + else: + optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) + return optimizer diff --git a/solver/lr_scheduler.py b/solver/lr_scheduler.py new file mode 100644 index 0000000..7e9f82e --- /dev/null +++ b/solver/lr_scheduler.py @@ -0,0 +1,56 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" +from bisect import bisect_right +import torch + + +# FIXME ideally this would be achieved with a CombinedLRScheduler, +# separating MultiStepLR with WarmupLR +# but the current LRScheduler design doesn't allow it + +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + optimizer, + milestones, + gamma=0.1, + warmup_factor=1.0 / 3, + warmup_iters=500, + warmup_method="linear", + last_epoch=-1, + ): + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", + milestones, + ) + + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted" + "got {}".format(warmup_method) + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + warmup_factor = 1 + if self.last_epoch < self.warmup_iters: + if self.warmup_method == "constant": + warmup_factor = self.warmup_factor + elif self.warmup_method == "linear": + alpha = self.last_epoch / self.warmup_iters + warmup_factor = self.warmup_factor * (1 - alpha) + alpha + return [ + base_lr + * warmup_factor + * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e404ab1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" diff --git a/tests/lr_scheduler_test.py b/tests/lr_scheduler_test.py new file mode 100644 index 0000000..eb9ee4a --- /dev/null +++ b/tests/lr_scheduler_test.py @@ -0,0 +1,26 @@ +import sys +import unittest + +import torch +from torch import nn + +sys.path.append('.') +from solver.lr_scheduler import WarmupMultiStepLR +from solver.build import make_optimizer +from config import cfg + + +class MyTestCase(unittest.TestCase): + def test_something(self): + net = nn.Linear(10, 10) + optimizer = make_optimizer(cfg, net) + lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10) + for i in range(50): + lr_scheduler.step() + for j in range(3): + print(i, lr_scheduler.get_lr()[0]) + optimizer.step() + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/__init__.py b/tools/__init__.py index 0572b5a..e404ab1 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -3,9 +3,3 @@ @author: sherlock @contact: sherlockliao01@gmail.com """ - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - diff --git a/tools/test.py b/tools/test.py index 74b58ed..886b2d5 100644 --- a/tools/test.py +++ b/tools/test.py @@ -4,64 +4,61 @@ @contact: sherlockliao01@gmail.com """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - import argparse -import logging import os import sys -from pprint import pprint +from os import mkdir import torch -from torch import nn from torch.backends import cudnn -import network -from core.config import opt, update_config -from core.loader import get_data_provider -from core.solver import Solver - -FORMAT = '[%(levelname)s]: %(message)s' -logging.basicConfig( - level=logging.INFO, - format=FORMAT, - stream=sys.stdout -) - - -def test(args): - logging.info('======= user config ======') - logging.info(pprint(opt)) - logging.info(pprint(args)) - logging.info('======= end ======') - - train_data, test_data, num_query = get_data_provider(opt) - - net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride) - net.load_state_dict(torch.load(args.load_model)['state_dict']) - net = nn.DataParallel(net).cuda() - - mod = Solver(opt, net) - mod.test_func(test_data, num_query) +sys.path.append('.') +from config import cfg +from data import make_data_loader +from engine.inference import inference +from modeling import build_model +from utils.logger import setup_logger def main(): - parser = argparse.ArgumentParser(description='reid model testing') - parser.add_argument('--config_file', type=str, default=None, - help='Optional config file for params') - parser.add_argument('--load_model', type=str, required=True, - help='load trained model for testing') + parser = argparse.ArgumentParser(description="ReID Baseline Inference") + parser.add_argument( + "--config_file", default="", help="path to config file", type=str + ) + parser.add_argument("opts", help="Modify config options using the command-line", default=None, + nargs=argparse.REMAINDER) args = parser.parse_args() - if args.config_file is not None: - update_config(args.config_file) - os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + + if args.config_file != "": + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + + output_dir = cfg.OUTPUT_DIR + if output_dir and not os.path.exists(output_dir): + mkdir(output_dir) + + logger = setup_logger("reid_baseline", output_dir, 0) + logger.info("Using {} GPUS".format(num_gpus)) + logger.info(args) + + if args.config_file != "": + logger.info("Loaded configuration file {}".format(args.config_file)) + with open(args.config_file, 'r') as cf: + config_str = "\n" + cf.read() + logger.info(config_str) + logger.info("Running with config:\n{}".format(cfg)) + cudnn.benchmark = True - test(args) + + train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) + model = build_model(cfg, num_classes) + model.load_state_dict(torch.load(cfg.TEST.WEIGHT)) + + inference(cfg, model, val_loader, num_query) if __name__ == '__main__': diff --git a/tools/train.py b/tools/train.py index 0db20ec..aa1f8a5 100644 --- a/tools/train.py +++ b/tools/train.py @@ -4,95 +4,83 @@ @contact: sherlockliao01@gmail.com """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - import argparse -import logging import os import sys -from pprint import pprint -import torch -from torch import nn from torch.backends import cudnn -import network -from core.config import opt, update_config -from core.loader import get_data_provider -from core.solver import Solver -from utils.loss import TripletLoss -from utils.lr_scheduler import LRScheduler +sys.path.append('.') +from config import cfg +from data import make_data_loader +from engine.trainer import do_train +from modeling import build_model +from layers import make_loss +from solver import make_optimizer, WarmupMultiStepLR -FORMAT = '[%(levelname)s]: %(message)s' -logging.basicConfig( - level=logging.INFO, - format=FORMAT, - stream=sys.stdout -) +from utils.logger import setup_logger -def train(args): - logging.info('======= user config ======') - logging.info(pprint(opt)) - logging.info(pprint(args)) - logging.info('======= end ======') +def train(cfg): + # prepare dataset + train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) + # prepare model + model = build_model(cfg, num_classes) - train_data, test_data, num_query = get_data_provider(opt) + optimizer = make_optimizer(cfg, model) + scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, + cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) - net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride) + loss_func = make_loss(cfg) - optimizer = getattr(torch.optim, opt.train.optimizer)(net.parameters(), lr=opt.train.lr, weight_decay=opt.train.wd) - ce_loss = nn.CrossEntropyLoss() - triplet_loss = TripletLoss(margin=opt.train.margin) + arguments = {} - def ce_loss_func(scores, feat, labels): - ce = ce_loss(scores, labels) - return ce - - def tri_loss_func(scores, feat, labels): - tri = triplet_loss(feat, labels)[0] - return tri - - def ce_tri_loss_func(scores, feat, labels): - ce = ce_loss(scores, labels) - triplet = triplet_loss(feat, labels)[0] - return ce + triplet - - if opt.train.loss_fn == 'softmax': - loss_fn = ce_loss_func - elif opt.train.loss_fn == 'triplet': - loss_fn = tri_loss_func - elif opt.train.loss_fn == 'softmax_triplet': - loss_fn = ce_tri_loss_func - else: - raise ValueError('Unknown loss func {}'.format(opt.train.loss_fn)) - - lr_scheduler = LRScheduler(base_lr=opt.train.lr, step=opt.train.step, - factor=opt.train.factor, warmup_epoch=opt.train.warmup_epoch, - warmup_begin_lr=opt.train.warmup_begin_lr) - net = nn.DataParallel(net).cuda() - mod = Solver(opt, net) - mod.fit(train_data=train_data, test_data=test_data, num_query=num_query, optimizer=optimizer, - criterion=loss_fn, lr_scheduler=lr_scheduler) + do_train( + cfg, + model, + train_loader, + val_loader, + optimizer, + scheduler, + loss_func, + num_query + ) def main(): - parser = argparse.ArgumentParser(description='reid model training') - parser.add_argument('--config_file', type=str, default=None, required=True, - help='Optional config file for params') - parser.add_argument('--save_dir', type=str, default=None, required=True, - help='model save checkpoint directory') + parser = argparse.ArgumentParser(description="ReID Baseline Training") + parser.add_argument( + "--config_file", default="", help="path to config file", type=str + ) + parser.add_argument("opts", help="Modify config options using the command-line", default=None, + nargs=argparse.REMAINDER) args = parser.parse_args() - if args.config_file is not None: - update_config(args.config_file) - opt.misc.save_dir = args.save_dir - os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + + if args.config_file != "": + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + + output_dir = cfg.OUTPUT_DIR + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + logger = setup_logger("reid_baseline", output_dir, 0) + logger.info("Using {} GPUS".format(num_gpus)) + logger.info(args) + + if args.config_file != "": + logger.info("Loaded configuration file {}".format(args.config_file)) + with open(args.config_file, 'r') as cf: + config_str = "\n" + cf.read() + logger.info(config_str) + logger.info("Running with config:\n{}".format(cfg)) + cudnn.benchmark = True - train(args) + train(cfg) if __name__ == '__main__': diff --git a/utils/__init__.py b/utils/__init__.py index 4ec04de..42be7d8 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,11 +1,6 @@ # encoding: utf-8 """ -@author: liaoxingyu -@contact: xyliao1993@qq.com +@author: sherlock +@contact: sherlockliao01@gmail.com """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - diff --git a/utils/iotools.py b/utils/iotools.py new file mode 100644 index 0000000..2aac16d --- /dev/null +++ b/utils/iotools.py @@ -0,0 +1,39 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +import errno +import json +import os + +import os.path as osp + + +def mkdir_if_missing(directory): + if not osp.exists(directory): + try: + os.makedirs(directory) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def check_isfile(path): + isfile = osp.isfile(path) + if not isfile: + print("=> Warning: no file found at '{}' (ignored)".format(path)) + return isfile + + +def read_json(fpath): + with open(fpath, 'r') as f: + obj = json.load(f) + return obj + + +def write_json(obj, fpath): + mkdir_if_missing(osp.dirname(fpath)) + with open(fpath, 'w') as f: + json.dump(obj, f, indent=4, separators=(',', ': ')) diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..a9c32e7 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,30 @@ +# encoding: utf-8 +""" +@author: sherlock +@contact: sherlockliao01@gmail.com +""" + +import logging +import os +import sys + + +def setup_logger(name, save_dir, distributed_rank): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + if save_dir: + fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger diff --git a/utils/lr_scheduler.py b/utils/lr_scheduler.py deleted file mode 100644 index 9a74d9c..0000000 --- a/utils/lr_scheduler.py +++ /dev/null @@ -1,65 +0,0 @@ -# encoding: utf-8 -""" -@author: sherlock -@contact: sherlockliao01@gmail.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - - -class LRScheduler(object): - """Base class of a learning rate scheduler. - - A scheduler returns a new learning rate based on the number of updates that have - been performed. - - Parameters - ---------- - base_lr : float, optional - The initial learning rate. - warmup_epoch: int - number of warmup steps used before this scheduler starts decay - warmup_begin_lr: float - if using warmup, the learning rate from which it starts warming up - warmup_mode: string - warmup can be done in two modes. - 'linear' mode gradually increases lr with each step in equal increments - 'constant' mode keeps lr at warmup_begin_lr for warmup_steps - """ - - def __init__(self, base_lr=0.01, step=(30, 60), factor=0.1, - warmup_epoch=0, warmup_begin_lr=0, warmup_mode='linear'): - self.base_lr = base_lr - self.learning_rate = base_lr - self.step = step - self.factor = factor - assert isinstance(warmup_epoch, int) - self.warmup_epoch = warmup_epoch - - self.warmup_final_lr = base_lr - self.warmup_begin_lr = warmup_begin_lr - if self.warmup_begin_lr > self.warmup_final_lr: - raise ValueError("Base lr has to be higher than warmup_begin_lr") - if self.warmup_epoch < 0: - raise ValueError("Warmup steps has to be positive or 0") - if warmup_mode not in ['linear', 'constant']: - raise ValueError("Supports only linear and constant modes of warmup") - self.warmup_mode = warmup_mode - - def update(self, num_epoch): - if self.warmup_epoch > num_epoch: - # warmup strategy - if self.warmup_mode == 'linear': - self.learning_rate = self.warmup_begin_lr + (self.warmup_final_lr - self.warmup_begin_lr) * \ - num_epoch / self.warmup_epoch - elif self.warmup_mode == 'constant': - self.learning_rate = self.warmup_begin_lr - - else: - count = sum([1 for s in self.step if s <= num_epoch]) - self.learning_rate = self.base_lr * pow(self.factor, count) - return self.learning_rate - diff --git a/utils/meters.py b/utils/meters.py deleted file mode 100644 index fd2b9ea..0000000 --- a/utils/meters.py +++ /dev/null @@ -1,54 +0,0 @@ -# encoding: utf-8 -""" -@author: liaoxingyu -@contact: xyliao1993@qq.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import math - -import numpy as np - - -class AverageMeter(object): - def __init__(self, name): - self.name = name - self.n = 0 - self.sum = 0.0 - self.var = 0.0 - self.val = 0.0 - self.mean = np.nan - self.std = np.nan - - def update(self, value, n=1): - self.val = value - self.sum += value - self.var += value * value - self.n += n - - if self.n == 0: - self.mean, self.std = np.nan, np.nan - elif self.n == 1: - self.mean, self.std = self.sum, np.inf - else: - self.mean = self.sum / self.n - self.std = math.sqrt( - (self.var - self.n * self.mean * self.mean) / (self.n - 1.0)) - - def value(self): - return self.mean, self.std - - def get(self): - return self.name, self.mean - - def reset(self): - self.n = 0 - self.sum = 0.0 - self.var = 0.0 - self.val = 0.0 - self.mean = np.nan - self.std = np.nan diff --git a/utils/reid_metric.py b/utils/reid_metric.py new file mode 100644 index 0000000..644dbf0 --- /dev/null +++ b/utils/reid_metric.py @@ -0,0 +1,48 @@ +# encoding: utf-8 +""" +@author: liaoxingyu +@contact: sherlockliao01@gmail.com +""" + +import numpy as np +import torch +from ignite.metrics import Metric + +from data.datasets.eval_reid import eval_func + + +class R1_mAP(Metric): + def __init__(self, num_query, max_rank=50): + super(R1_mAP, self).__init__() + self.num_query = num_query + self.max_rank = max_rank + + def reset(self): + self.feats = [] + self.pids = [] + self.camids = [] + + def update(self, output): + feat, pid, camid = output + self.feats.append(feat) + self.pids.extend(np.asarray(pid)) + self.camids.extend(np.asarray(camid)) + + def compute(self): + feats = torch.cat(self.feats, dim=0) + # query + qf = feats[:self.num_query] + q_pids = np.asarray(self.pids[:self.num_query]) + q_camids = np.asarray(self.camids[:self.num_query]) + # gallery + gf = feats[self.num_query:] + g_pids = np.asarray(self.pids[self.num_query:]) + g_camids = np.asarray(self.camids[self.num_query:]) + m, n = qf.shape[0], gf.shape[0] + distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ + torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() + distmat.addmm_(1, -2, qf, gf.t()) + distmat = distmat.cpu().numpy() + cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) + + return cmc, mAP diff --git a/utils/serialization.py b/utils/serialization.py deleted file mode 100644 index f6d008f..0000000 --- a/utils/serialization.py +++ /dev/null @@ -1,35 +0,0 @@ -# encoding: utf-8 -""" -@author: liaoxingyu -@contact: xyliao1993@qq.com -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import errno -import os -import shutil -import sys - -import os.path as osp -import torch - - -def mkdir_if_missing(dir_path): - try: - os.makedirs(dir_path) - except OSError as e: - if e.errno != errno.EEXIST: - raise - - -def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'): - fpath = '_'.join((str(state['epoch']), filename)) - fpath = osp.join(save_dir, fpath) - mkdir_if_missing(save_dir) - torch.save(state, fpath) - if is_best: - shutil.copy(fpath, osp.join(save_dir, 'model_best.pth.tar'))