mirror of https://github.com/JDAI-CV/fast-reid.git
<feature> fix bugs and remove extra dependencies
1. Remove the pytorch-lightning dependency; the code is now pure PyTorch.
2. Add a prefetching dataloader to accelerate data loading.
3. Add IBN and a new training engine.
4. Add multi-dataset training support.
parent 15a911879a
commit 9a3d365e9b

README.md
@@ -5,7 +5,8 @@ A strong baseline (state-of-the-art) for person re-identification.
 We support
 - [x] easy dataset preparation
 - [x] end-to-end training and evaluation
-- [ ] multi-GPU distributed training
+- [x] multi-GPU distributed training
+- [x] fast data loader with prefetcher
 - [ ] fast training speed with fp16
 - [x] fast evaluation with cython
 - [ ] support both image and video reid
@@ -26,7 +27,7 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
 3. Install dependencies:
    - [pytorch 1.0.0+](https://pytorch.org/)
    - torchvision
-   - [pytorch-lightning](https://github.com/williamFalcon/pytorch-lightning)
+   - tensorboard
    - [yacs](https://github.com/rbgirshick/yacs)
 4. Prepare dataset

@@ -56,21 +57,21 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
 ## Train
 We provide several configuration files; for example, you can run this command to train on market1501:
 ```bash
-bash scripts/train_market.sh
+bash scripts/train_openset.sh
 ```
 
 Or you can run the command below and override cfg parameters directly:
 ```bash
-python3 tools/train.py -cfg='configs/softmax.yml' INPUT.SIZE_TRAIN '(256, 128)' INPUT.SIZE_TEST '(256, 128)'
+CUDA_VISIBLE_DEVICES='0,1' python tools/train.py -cfg='configs/softmax_triplet.yml' DATASETS.NAMES '("dukemtmc","market1501",)' SOLVER.IMS_PER_BATCH '256'
 ```
 
 ## Test
 You can test your model's performance directly by running this command:
 ```bash
-python3 tools/test.py DATASET.TEST_NAMES 'duke' \
-MODEL.BACKBONE 'resnet50' \
-MODEL.WITH_IBN 'True' \
-TEST.WEIGHT '/save/trained_model/path'
+CUDA_VISIBLE_DEVICES='0' python tools/test.py -cfg='configs/softmax_triplet.yml' DATASET.TEST_NAMES 'dukemtmc' \
+MODEL.BACKBONE 'resnet50' \
+MODEL.WITH_IBN 'True' \
+TEST.WEIGHT '/save/trained_model/path'
 ```
 
 ## Experiment Results
@@ -20,7 +20,8 @@ _C = CN()
 # MODEL
 # -----------------------------------------------------------------------------
 _C.MODEL = CN()
+_C.MODEL.GPUS = [0]
 _C.MODEL.NAME = 'baseline'
+_C.MODEL.DIST_BACKEND = 'dp'
 # Model backbone
 _C.MODEL.BACKBONE = 'resnet50'
 # Last stride for backbone
@@ -84,6 +85,7 @@ _C.DATALOADER = CN()
 _C.DATALOADER.SAMPLER = 'softmax'
 # Number of instance for each person
 _C.DATALOADER.NUM_INSTANCE = 4
+_C.DATALOADER.NUM_WORKERS = 8
 
 # ---------------------------------------------------------------------------- #
 # Solver
@@ -110,11 +112,11 @@ _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
 _C.SOLVER.GAMMA = 0.1
 _C.SOLVER.STEPS = (30, 55)
 
-_C.SOLVER.WARMUP_FACTOR = 1.0 / 3
-_C.SOLVER.WARMUP_ITERS = 500
+_C.SOLVER.WARMUP_FACTOR = 0.1
+_C.SOLVER.WARMUP_ITERS = 10
 _C.SOLVER.WARMUP_METHOD = "linear"
 
 _C.SOLVER.CHECKPOINT_PERIOD = 50
 _C.SOLVER.LOG_INTERVAL = 30
 _C.SOLVER.EVAL_PERIOD = 50
 # Number of images per batch
 # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
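The hunk above replaces the long, low-factor warmup (500 iterations starting at 1/3 of the base LR) with a short one (10 warmup steps starting at 0.1x). For intuition, here is a minimal sketch of how a linear warmup multiplier is commonly computed; this mirrors the familiar detectron-style schedule and is illustrative, not necessarily the exact scheduler this repo ships:

```python
def warmup_lr(base_lr, step, warmup_iters=10, warmup_factor=0.1, method='linear'):
    """Return the LR at `step`, ramping from warmup_factor * base_lr up to base_lr."""
    if step >= warmup_iters:
        return base_lr
    if method == 'constant':
        return base_lr * warmup_factor
    # linear: interpolate the multiplier from warmup_factor to 1
    alpha = step / warmup_iters
    return base_lr * (warmup_factor * (1 - alpha) + alpha)

# with the new settings, the LR starts at 0.1 * BASE_LR and reaches BASE_LR at step 10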
@@ -131,4 +133,4 @@ _C.TEST.WEIGHT = ""
 # ---------------------------------------------------------------------------- #
 # Misc options
 # ---------------------------------------------------------------------------- #
-_C.OUTPUT_DIR = ""
+_C.OUTPUT_DIR = "logs/"
@@ -8,17 +8,17 @@ DATALOADER:
 SOLVER:
   OPT: 'adam'
   LOSSTYPE: ('softmax',)
-  MAX_EPOCHS: 80
+  MAX_EPOCHS: 100
   BASE_LR: 0.00035
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64
 
-  STEPS: [30, 55]
+  STEPS: [30, 55, 80]
   GAMMA: 0.1
 
   WARMUP_FACTOR: 0.01
-  WARMUP_ITERS: 5
+  WARMUP_ITERS: 10
   WARMUP_METHOD: 'linear'
 
   CHECKPOINT_PERIOD: 20
@@ -27,6 +27,5 @@ SOLVER:
 TEST:
   IMS_PER_BATCH: 256
 
 OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_bs64_384x128"
-
@@ -5,7 +5,7 @@ MODEL:
 
 
 DATASETS:
-  NAMES: ("market1501",)
+  NAMES: ('market1501',)
   TEST_NAMES: "market1501"
 
 DATALOADER:
@@ -13,11 +13,10 @@ DATALOADER:
   NUM_INSTANCE: 4
 
 SOLVER:
-  OPT: 'adam'
+  OPT: "adam"
   LOSSTYPE: ('softmax', 'triplet')
-  MAX_EPOCHS: 150
-  # BASE_LR: 0.00035
-  BASE_LR: 0.0007
+  MAX_EPOCHS: 120
+  BASE_LR: 0.00035
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64
@@ -34,6 +33,5 @@ TEST:
   IMS_PER_BATCH: 512
   WEIGHT: "path"
 
 OUTPUT_DIR: "logs/market/batch256/"
-
@@ -4,4 +4,4 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from .build import get_dataloader
+from .build import get_dataloader, get_test_dataloader
@@ -10,69 +10,55 @@ import re
 
 from torch.utils.data import DataLoader
 
-from .collate_batch import tng_collate_fn
-from .datasets import ImageDataset, CUHK03
+from .collate_batch import fast_collate_fn
+from .datasets import ImageDataset
 from .samplers import RandomIdentitySampler
 from .transforms import build_transforms
+from .datasets import init_dataset
 
 
 def get_dataloader(cfg):
     tng_tfms = build_transforms(cfg, is_train=True)
     val_tfms = build_transforms(cfg, is_train=False)
 
-    def _process_dir(dir_path):
-        img_paths = []
-        img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
-        pattern = re.compile(r'([-\d]+)_c(\d*)')
-        v_paths = []
-        for img_path in img_paths:
-            pid, camid = map(int, pattern.search(img_path).groups())
-            pid = int(pid)
-            if pid == -1: continue  # junk images are just ignored
-            v_paths.append([img_path,pid,camid])
-        return v_paths
-
-    market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
-    duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
-    cuhk03_train_path = 'datasets/cuhk03/'
-
-    market_query_path = 'datasets/Market-1501-v15.09.15/query'
-    marker_gallery_path = 'datasets/Market-1501-v15.09.15/bounding_box_test'
-    duke_query_path = 'datasets/DukeMTMC-reID/query'
-    duek_gallery_path = 'datasets/DukeMTMC-reID/bounding_box_test'
-
     print('prepare training set ...')
     train_img_items = list()
     for d in cfg.DATASETS.NAMES:
-        if d == 'market1501': train_img_items.extend(_process_dir(market_train_path))
-        elif d == 'duke': train_img_items.extend(_process_dir(duke_train_path))
-        elif d == 'cuhk03': train_img_items.extend(CUHK03().train)
-        else:
-            raise NameError(f"{d} is not available")
+        # dataset = init_dataset(d, combineall=True)
+        dataset = init_dataset(d)
+        train_img_items.extend(dataset.train)
 
-    if cfg.DATASETS.TEST_NAMES == "market1501":
-        query_names = _process_dir(market_query_path)
-        gallery_names = _process_dir(marker_gallery_path)
-    elif cfg.DATASETS.TEST_NAMES == 'duke':
-        query_names = _process_dir(duke_query_path)
-        gallery_names = _process_dir(duek_gallery_path)
-    else:
-        print(f"not support {cfg.DATASETS.TEST_NAMES} test set")
-
-    num_workers = min(16, len(os.sched_getaffinity(0)))
+    print('prepare test set ...')
+    dataset = init_dataset(cfg.DATASETS.TEST_NAMES)
+    query_names, gallery_names = dataset.query, dataset.gallery
 
     tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
-    if cfg.DATALOADER.SAMPLER == 'softmax':
-        tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
-                                    num_workers=num_workers, collate_fn=tng_collate_fn,
-                                    pin_memory=True)
-    elif cfg.DATALOADER.SAMPLER == 'triplet':
-        data_sampler = RandomIdentitySampler(train_img_items, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
-        tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, sampler=data_sampler,
-                                    num_workers=num_workers, collate_fn=tng_collate_fn,
-                                    pin_memory=True)
-    else:
-        raise NameError(f"{cfg.DATALOADER.SAMPLER} sampler is not support")
+    num_workers = cfg.DATALOADER.NUM_WORKERS
+    data_sampler = None
+    if cfg.DATALOADER.SAMPLER == 'triplet':
+        data_sampler = RandomIdentitySampler(tng_set.img_items, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
 
+    tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=(data_sampler is None),
+                                num_workers=num_workers, sampler=data_sampler,
+                                collate_fn=fast_collate_fn, pin_memory=True)
 
     val_set = ImageDataset(query_names+gallery_names, val_tfms, relabel=False)
-    val_dataloader = DataLoader(val_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_cpus)
+    val_dataloader = DataLoader(val_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
+                                collate_fn=fast_collate_fn, pin_memory=True)
     return tng_dataloader, val_dataloader, tng_set.c, len(query_names)
+
+
+def get_test_dataloader(cfg):
+    val_tfms = build_transforms(cfg, is_train=False)
+
+    print('prepare test set ...')
+    dataset = init_dataset(cfg.DATASETS.TEST_NAMES)
+    query_names, gallery_names = dataset.query, dataset.gallery
+
+    num_workers = cfg.DATALOADER.NUM_WORKERS
+
+    test_set = ImageDataset(query_names+gallery_names, val_tfms, relabel=False)
+    test_dataloader = DataLoader(test_set, cfg.TEST.IMS_PER_BATCH, num_workers=num_workers,
+                                 collate_fn=fast_collate_fn, pin_memory=True)
+    return test_dataloader, len(query_names)
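For orientation, a sketch of how these loaders would be consumed. The `cfg` import path is assumed from the yacs defaults shown earlier; the actual entry point lives in tools/train.py and may differ:

```python
from config import cfg  # assumed import path for the yacs defaults above
from data import get_dataloader, get_test_dataloader

cfg.merge_from_file('configs/softmax_triplet.yml')

tng_loader, val_loader, num_classes, num_query = get_dataloader(cfg)
for imgs, pids, camids in tng_loader:
    # imgs is a uint8 (N, 3, H, W) tensor produced by fast_collate_fn;
    # float conversion and normalization are deferred to the GPU side
    pass
```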
@@ -4,9 +4,31 @@
 @contact: sherlockliao01@gmail.com
 """
 import torch
+import numpy as np
 
 
-def tng_collate_fn(batch):
+def fast_collate_fn(batch):
     imgs, pids, camids = zip(*batch)
-    return torch.stack(imgs, dim=0), torch.tensor(pids).long()
+    is_ndarray = isinstance(imgs[0], np.ndarray)
+    if not is_ndarray:  # PIL Image object
+        w = imgs[0].size[0]
+        h = imgs[0].size[1]
+    else:
+        w = imgs[0].shape[1]
+        h = imgs[0].shape[0]
+    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
+    for i, img in enumerate(imgs):
+        if not is_ndarray:
+            img = np.asarray(img, dtype=np.uint8)
+        numpy_array = np.rollaxis(img, 2)
+        tensor[i] += torch.from_numpy(numpy_array)
+    return tensor, torch.tensor(pids).long(), camids
 
 
 # def dcl_collate_fn(batch):
 #     imgs, swap_imgs, pids = zip(*batch)
 #     imgs = torch.stack(imgs, dim=0)
 #     swap_imgs = torch.stack(swap_imgs, dim=0)
 #     # pids *= 2
 #     swap_labels = [1] * imgs.size()[0] + [0] * swap_imgs.size()[0]
 #     # return torch.cat([imgs, swap_imgs], dim=0), (tensor(pids).long(), tensor(swap_labels).long())
 #     return imgs, (tensor(pids).long(), tensor(swap_labels).long())
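`fast_collate_fn` deliberately returns raw uint8 tensors so that float conversion and normalization can be overlapped with compute on the GPU, which is what the "prefetcher" in the commit message refers to. Below is a minimal sketch of the apex-style prefetcher such a collate function is designed to feed; the class and the `mean`/`std` defaults are illustrative assumptions, not necessarily the repo's exact implementation. It relies on the loader using `pin_memory=True` so the `non_blocking` copies are truly asynchronous:

```python
import torch

class DataPrefetcher:
    """Overlap host-to-device copy and normalization with compute on a side CUDA stream."""

    def __init__(self, loader, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # scale to the 0-255 range because the incoming batch is uint8
        self.mean = torch.tensor(mean).cuda().view(1, 3, 1, 1) * 255
        self.std = torch.tensor(std).cuda().view(1, 3, 1, 1) * 255
        self.preload()

    def preload(self):
        try:
            self.next_imgs, self.next_pids, self.next_camids = next(self.loader)
        except StopIteration:
            self.next_imgs = None
            return
        with torch.cuda.stream(self.stream):
            # async copy of the uint8 batch, then normalize on the GPU
            imgs = self.next_imgs.cuda(non_blocking=True).float()
            self.next_imgs = imgs.sub_(self.mean).div_(self.std)
            self.next_pids = self.next_pids.cuda(non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        imgs, pids, camids = self.next_imgs, self.next_pids, self.next_camids
        if imgs is not None:
            self.preload()  # kick off the copy of the following batch
        return imgs, pids, camids  # camids stays a CPU-side tuple
```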
@@ -6,12 +6,14 @@
 from .cuhk03 import CUHK03
 from .dukemtmcreid import DukeMTMCreID
 from .market1501 import Market1501
+from .msmt17 import MSMT17
 from .dataset_loader import *
 
 __factory = {
     'market1501': Market1501,
     'cuhk03': CUHK03,
-    'dukemtmc': DukeMTMCreID
+    'dukemtmc': DukeMTMCreID,
+    'msmt17': MSMT17
 }
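`build.py` above calls `init_dataset`, whose definition is not part of this diff. A minimal sketch of how such a factory lookup over `__factory` typically works, shown here as an assumption rather than the repo's exact code:

```python
def init_dataset(name, **kwargs):
    # map a config string like 'msmt17' to its dataset class and instantiate it
    if name not in __factory:
        raise KeyError('Unknown dataset: {} (choices are {})'.format(name, list(__factory.keys())))
    return __factory[name](**kwargs)

# e.g. dataset = init_dataset('msmt17', root='datasets')
```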
@@ -4,92 +4,309 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import copy
+import os
+
 import numpy as np
 import torch
+from PIL import Image
 
 
-class BaseDataset(object):
-    """
-    Base class of reid dataset
+def read_image(img_path):
+    """Keep reading image until succeed.
+    This can avoid IOError incurred by heavy IO process."""
+    got_img = False
+    if not os.path.exists(img_path):
+        raise IOError("{} does not exist".format(img_path))
+    while not got_img:
+        try:
+            img = Image.open(img_path).convert('RGB')
+            got_img = True
+        except IOError:
+            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
+            pass
+    return img
+
+
+class Dataset(object):
+    """An abstract class representing a Dataset.
+    This is the base class for ``ImageDataset`` and ``VideoDataset``.
+    Args:
+        train (list): contains tuples of (img_path(s), pid, camid).
+        query (list): contains tuples of (img_path(s), pid, camid).
+        gallery (list): contains tuples of (img_path(s), pid, camid).
+        transform: transform function.
+        mode (str): 'train', 'query' or 'gallery'.
+        combineall (bool): combines train, query and gallery in a
+            dataset for training.
+        verbose (bool): show information.
     """
+    _junk_pids = []  # contains useless person IDs, e.g. background, false detections
 
-    def get_imagedata_info(self, data):
-        pids, cams = [], []
-        for _, pid, camid in data:
-            pids += [pid]
-            cams += [camid]
-        pids = set(pids)
-        cams = set(cams)
-        num_pids = len(pids)
-        num_cams = len(cams)
-        num_imgs = len(data)
-        return num_pids, num_imgs, num_cams
+    def __init__(self, train, query, gallery, transform=None, mode='train',
+                 combineall=False, verbose=True, **kwargs):
+        self.train = train
+        self.query = query
+        self.gallery = gallery
+        self.transform = transform
+        self.mode = mode
+        self.combineall = combineall
+        self.verbose = verbose
 
-    def get_videodata_info(self, data, return_tracklet_stats=False):
-        pids, cams, tracklet_stats = [], [], []
-        for img_paths, pid, camid in data:
-            pids += [pid]
-            cams += [camid]
-            tracklet_stats += [len(img_paths)]
-        pids = set(pids)
-        cams = set(cams)
-        num_pids = len(pids)
-        num_cams = len(cams)
-        num_tracklets = len(data)
-        if return_tracklet_stats:
-            return num_pids, num_tracklets, num_cams, tracklet_stats
-        return num_pids, num_tracklets, num_cams
+        self.num_train_pids = self.get_num_pids(self.train)
+        self.num_train_cams = self.get_num_cams(self.train)
 
-    def print_dataset_statistics(self):
+        if self.combineall:
+            self.combine_all()
+
+        if self.mode == 'train':
+            self.data = self.train
+        elif self.mode == 'query':
+            self.data = self.query
+        elif self.mode == 'gallery':
+            self.data = self.gallery
+        else:
+            raise ValueError('Invalid mode. Got {}, but expected to be '
+                             'one of [train | query | gallery]'.format(self.mode))
+
+        if self.verbose:
+            self.show_summary()
+
+    def __getitem__(self, index):
         raise NotImplementedError
 
+    def __len__(self):
+        return len(self.data)
 
-class BaseImageDataset(BaseDataset):
-    """
-    Base class of image reid dataset
-    """
+    def __add__(self, other):
+        """Adds two datasets together (only the train set)."""
+        train = copy.deepcopy(self.train)
+
+        for img_path, pid, camid in other.train:
+            pid += self.num_train_pids
+            camid += self.num_train_cams
+            train.append((img_path, pid, camid))
+
+        ###################################
+        # Things to do beforehand:
+        # 1. set verbose=False to avoid unnecessary print
+        # 2. set combineall=False because combineall would have been applied
+        #    if it was True for a specific dataset, setting it to True will
+        #    create new IDs that should have been included
+        ###################################
+        if isinstance(train[0][0], str):
+            return ImageDataset(
+                train, self.query, self.gallery,
+                transform=self.transform,
+                mode=self.mode,
+                combineall=False,
+                verbose=False
+            )
+        else:
+            return VideoDataset(
+                train, self.query, self.gallery,
+                transform=self.transform,
+                mode=self.mode,
+                combineall=False,
+                verbose=False
+            )
+
+    def __radd__(self, other):
+        """Supports sum([dataset1, dataset2, dataset3])."""
+        if other == 0:
+            return self
+        else:
+            return self.__add__(other)
+
+    def parse_data(self, data):
+        """Parses data list and returns the number of person IDs
+        and the number of camera views.
+        Args:
+            data (list): contains tuples of (img_path(s), pid, camid)
+        """
+        pids = set()
+        cams = set()
+        for _, pid, camid in data:
+            pids.add(pid)
+            cams.add(camid)
+        return len(pids), len(cams)
+
+    def get_num_pids(self, data):
+        """Returns the number of training person identities."""
+        return self.parse_data(data)[0]
+
+    def get_num_cams(self, data):
+        """Returns the number of training cameras."""
+        return self.parse_data(data)[1]
+
+    def show_summary(self):
+        """Shows dataset statistics."""
+        pass
+
+    def combine_all(self):
+        """Combines train, query and gallery in a dataset for training."""
+        combined = copy.deepcopy(self.train)
+
+        # relabel pids in gallery (query shares the same scope)
+        g_pids = set()
+        for _, pid, _ in self.gallery:
+            if pid in self._junk_pids:
+                continue
+            g_pids.add(pid)
+        pid2label = {pid: i for i, pid in enumerate(g_pids)}
+
+        def _combine_data(data):
+            for img_path, pid, camid in data:
+                if pid in self._junk_pids:
+                    continue
+                pid = pid2label[pid] + self.num_train_pids
+                combined.append((img_path, pid, camid))
+
+        _combine_data(self.query)
+        _combine_data(self.gallery)
+
+        self.train = combined
+        self.num_train_pids = self.get_num_pids(self.train)
+
+    def check_before_run(self, required_files):
+        """Checks if required files exist before going deeper.
+        Args:
+            required_files (str or list): string file name(s).
+        """
+        if isinstance(required_files, str):
+            required_files = [required_files]
+
+        for fpath in required_files:
+            if not os.path.exists(fpath):
+                raise RuntimeError('"{}" is not found'.format(fpath))
+
+    def __repr__(self):
+        num_train_pids, num_train_cams = self.parse_data(self.train)
+        num_query_pids, num_query_cams = self.parse_data(self.query)
+        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
+
+        msg = '  ----------------------------------------\n' \
+              '  subset   | # ids | # items | # cameras\n' \
+              '  ----------------------------------------\n' \
+              '  train    | {:5d} | {:7d} | {:9d}\n' \
+              '  query    | {:5d} | {:7d} | {:9d}\n' \
+              '  gallery  | {:5d} | {:7d} | {:9d}\n' \
+              '  ----------------------------------------\n' \
+              '  items: images/tracklets for image/video dataset\n'.format(
+                  num_train_pids, len(self.train), num_train_cams,
+                  num_query_pids, len(self.query), num_query_cams,
+                  num_gallery_pids, len(self.gallery), num_gallery_cams
+              )
+
+        return msg
+
+
+class ImageDataset(Dataset):
+    """A base class representing ImageDataset.
+    All other image datasets should subclass it.
+    ``__getitem__`` returns an image given index.
+    It will return ``img``, ``pid``, ``camid`` and ``img_path``
+    where ``img`` has shape (channel, height, width). As a result,
+    data in each batch has shape (batch_size, channel, height, width).
+    """
 
-    def print_dataset_statistics(self, train, query, gallery):
-        num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
-        num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
-        num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
+    def __init__(self, train, query, gallery, **kwargs):
+        super(ImageDataset, self).__init__(train, query, gallery, **kwargs)
 
-        print("Dataset statistics:")
-        print("  ----------------------------------------")
-        print("  subset   | # ids | # images | # cameras")
-        print("  ----------------------------------------")
-        print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
-        print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
-        print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
-        print("  ----------------------------------------")
+    def __getitem__(self, index):
+        img_path, pid, camid = self.data[index]
+        img = read_image(img_path)
+        if self.transform is not None:
+            img = self.transform(img)
+        return img, pid, camid, img_path
+
+    def show_summary(self):
+        num_train_pids, num_train_cams = self.parse_data(self.train)
+        num_query_pids, num_query_cams = self.parse_data(self.query)
+        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
+
+        print('=> Loaded {}'.format(self.__class__.__name__))
+        print('  ----------------------------------------')
+        print('  subset   | # ids | # images | # cameras')
+        print('  ----------------------------------------')
+        print('  train    | {:5d} | {:8d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
+        print('  query    | {:5d} | {:8d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
+        print('  gallery  | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
+        print('  ----------------------------------------')
 
 
-class BaseVideoDataset(BaseDataset):
-    """
-    Base class of video reid dataset
-    """
+class VideoDataset(Dataset):
+    """A base class representing VideoDataset.
+    All other video datasets should subclass it.
+    ``__getitem__`` returns an image given index.
+    It will return ``imgs``, ``pid`` and ``camid``
+    where ``imgs`` has shape (seq_len, channel, height, width). As a result,
+    data in each batch has shape (batch_size, seq_len, channel, height, width).
+    """
 
-    def print_dataset_statistics(self, train, query, gallery):
-        num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
-            self.get_videodata_info(train, return_tracklet_stats=True)
+    def __init__(self, train, query, gallery, seq_len=15, sample_method='evenly', **kwargs):
+        super(VideoDataset, self).__init__(train, query, gallery, **kwargs)
+        self.seq_len = seq_len
+        self.sample_method = sample_method
 
-        num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
-            self.get_videodata_info(query, return_tracklet_stats=True)
+        if self.transform is None:
+            raise RuntimeError('transform must not be None')
 
-        num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
-            self.get_videodata_info(gallery, return_tracklet_stats=True)
+    def __getitem__(self, index):
+        img_paths, pid, camid = self.data[index]
+        num_imgs = len(img_paths)
 
-        tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
-        min_num = np.min(tracklet_stats)
-        max_num = np.max(tracklet_stats)
-        avg_num = np.mean(tracklet_stats)
+        if self.sample_method == 'random':
+            # Randomly samples seq_len images from a tracklet of length num_imgs,
+            # if num_imgs is smaller than seq_len, then replicates images
+            indices = np.arange(num_imgs)
+            replace = False if num_imgs >= self.seq_len else True
+            indices = np.random.choice(indices, size=self.seq_len, replace=replace)
+            # sort indices to keep temporal order (comment it to be order-agnostic)
+            indices = np.sort(indices)
 
-        print("Dataset statistics:")
-        print("  -------------------------------------------")
-        print("  subset   | # ids | # tracklets | # cameras")
-        print("  -------------------------------------------")
-        print("  train    | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
-        print("  query    | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
-        print("  gallery  | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
-        print("  -------------------------------------------")
-        print("  number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
-        print("  -------------------------------------------")
+        elif self.sample_method == 'evenly':
+            # Evenly samples seq_len images from a tracklet
+            if num_imgs >= self.seq_len:
+                num_imgs -= num_imgs % self.seq_len
+                indices = np.arange(0, num_imgs, num_imgs / self.seq_len)
+            else:
+                # if num_imgs is smaller than seq_len, simply replicate the last image
+                # until the seq_len requirement is satisfied
+                indices = np.arange(0, num_imgs)
+                num_pads = self.seq_len - num_imgs
+                indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32) * (num_imgs - 1)])
+            assert len(indices) == self.seq_len
+
+        elif self.sample_method == 'all':
+            # Samples all images in a tracklet. batch_size must be set to 1
+            indices = np.arange(num_imgs)
+
+        else:
+            raise ValueError('Unknown sample method: {}'.format(self.sample_method))
+
+        imgs = []
+        for index in indices:
+            img_path = img_paths[int(index)]
+            img = read_image(img_path)
+            if self.transform is not None:
+                img = self.transform(img)
+            img = img.unsqueeze(0)  # img must be torch.Tensor
+            imgs.append(img)
+        imgs = torch.cat(imgs, dim=0)
+
+        return imgs, pid, camid
+
+    def show_summary(self):
+        num_train_pids, num_train_cams = self.parse_data(self.train)
+        num_query_pids, num_query_cams = self.parse_data(self.query)
+        num_gallery_pids, num_gallery_cams = self.parse_data(self.gallery)
+
+        print('=> Loaded {}'.format(self.__class__.__name__))
+        print('  -------------------------------------------')
+        print('  subset   | # ids | # tracklets | # cameras')
+        print('  -------------------------------------------')
+        print('  train    | {:5d} | {:11d} | {:9d}'.format(num_train_pids, len(self.train), num_train_cams))
+        print('  query    | {:5d} | {:11d} | {:9d}'.format(num_query_pids, len(self.query), num_query_cams))
+        print('  gallery  | {:5d} | {:11d} | {:9d}'.format(num_gallery_pids, len(self.gallery), num_gallery_cams))
+        print('  -------------------------------------------')
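The `__add__`/`__radd__` pair above is what makes multi-dataset training composable: summing datasets concatenates their train sets while offsetting person IDs and camera IDs to keep them disjoint. A short usage sketch (the constructor arguments are assumed from the classes later in this diff):

```python
# query/gallery of the combined dataset stay those of the first operand
combined = sum([Market1501(root='datasets', verbose=False),
                DukeMTMCreID(root='datasets', verbose=False)])
print(combined.num_train_pids)  # Duke pids are shifted by Market's num_train_pids
```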
@@ -4,38 +4,34 @@
 @contact: liaoxingyu2@jd.com
 """
 
-import h5py
 import os.path as osp
-from scipy.io import loadmat
-from scipy.misc import imsave
 
 from utils.iotools import mkdir_if_missing, write_json, read_json
-from .bases import BaseImageDataset
+from .bases import ImageDataset
 
 
-class CUHK03(BaseImageDataset):
-    """
-    CUHK03
+class CUHK03(ImageDataset):
+    """CUHK03.
+
     Reference:
-    Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
-    URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!
+        Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
+
+    URL: `<http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!>`_
+
     Dataset statistics:
-    # identities: 1360
-    # images: 13164
-    # cameras: 6
-    # splits: 20 (classic)
-    Args:
-        split_id (int): split index (default: 0)
-        cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
+        - identities: 1360.
+        - images: 13164.
+        - cameras: 6.
+        - splits: 20 (classic).
     """
     dataset_dir = 'cuhk03'
+    dataset_url = None
 
-    def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
-        # self.root = osp.abspath(osp.expanduser(root))
-        self.root = root
-        self.dataset_dir = osp.join(self.root, self.dataset_dir)
-
+    def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False,
+                 cuhk03_classic_split=False, verbose=False,
+                 **kwargs):
+        super(CUHK03, self).__init__()
+        self.dataset_dir = osp.join(root, self.dataset_dir)
         self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
         self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
@@ -51,72 +47,54 @@ class CUHK03(BaseImageDataset):
         self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
         self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')
 
-        self._check_before_run()
-        self._preprocess()
+        required_files = [
+            self.dataset_dir,
+            self.data_dir,
+            self.raw_mat_path,
+            self.split_new_det_mat_path,
+            self.split_new_lab_mat_path
+        ]
+        self.check_before_run(required_files)
+
+        self.preprocess_split()
 
         if cuhk03_labeled:
             image_type = 'labeled'
             split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
         else:
             image_type = 'detected'
             split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path
 
         splits = read_json(split_path)
-        assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id,
+        assert split_id < len(splits), 'Condition split_id ({}) < len(splits) ({}) is false'.format(split_id,
                                                                                                     len(splits))
         split = splits[split_id]
-        print("Split index = {}".format(split_id))
 
         train = split['train']
         query = split['query']
         gallery = split['gallery']
 
-        if verbose:
-            print("=> CUHK03 ({}) loaded".format(image_type))
-            self.print_dataset_statistics(train, query, gallery)
+        super(CUHK03, self).__init__(train, query, gallery, **kwargs)
 
-        self.train = train
-        self.query = query
-        self.gallery = gallery
-
-        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
-        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
-        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
-
-    def _check_before_run(self):
-        """Check if all files are available before going deeper"""
-        if not osp.exists(self.dataset_dir):
-            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
-        if not osp.exists(self.data_dir):
-            raise RuntimeError("'{}' is not available".format(self.data_dir))
-        if not osp.exists(self.raw_mat_path):
-            raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
-        if not osp.exists(self.split_new_det_mat_path):
-            raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
-        if not osp.exists(self.split_new_lab_mat_path):
-            raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))
-
-    def _preprocess(self):
-        """
-        This function is a bit complex and ugly, what it does is
-        1. Extract data from cuhk-03.mat and save as png images.
-        2. Create 20 classic splits. (Li et al. CVPR'14)
-        3. Create new split. (Zhong et al. CVPR'17)
-        """
-        print(
-            "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)")
-        if osp.exists(self.imgs_labeled_dir) and \
-                osp.exists(self.imgs_detected_dir) and \
-                osp.exists(self.split_classic_det_json_path) and \
-                osp.exists(self.split_classic_lab_json_path) and \
-                osp.exists(self.split_new_det_json_path) and \
-                osp.exists(self.split_new_lab_json_path):
+    def preprocess_split(self):
+        # This function is a bit complex and ugly, what it does is
+        # 1. extract data from cuhk-03.mat and save as png images
+        # 2. create 20 classic splits (Li et al. CVPR'14)
+        # 3. create new split (Zhong et al. CVPR'17)
+        if osp.exists(self.imgs_labeled_dir) \
+                and osp.exists(self.imgs_detected_dir) \
+                and osp.exists(self.split_classic_det_json_path) \
+                and osp.exists(self.split_classic_lab_json_path) \
+                and osp.exists(self.split_new_det_json_path) \
+                and osp.exists(self.split_new_lab_json_path):
             return
 
+        import h5py
+        from imageio import imwrite
+        from scipy.io import loadmat
+
         mkdir_if_missing(self.imgs_detected_dir)
         mkdir_if_missing(self.imgs_labeled_dir)
 
-        print("Extract image data from {} and save as png".format(self.raw_mat_path))
+        print('Extract image data from "{}" and save as png'.format(self.raw_mat_path))
         mat = h5py.File(self.raw_mat_path, 'r')
 
         def _deref(ref):
@@ -126,8 +104,8 @@ class CUHK03(BaseImageDataset):
             img_paths = []  # Note: some persons only have images for one view
             for imgid, img_ref in enumerate(img_refs):
                 img = _deref(img_ref)
-                # skip empty cell
-                if img.size == 0 or img.ndim < 3: continue
+                if img.size == 0 or img.ndim < 3:
+                    continue  # skip empty cell
                 # images are saved with the following format, index-1 (ensure uniqueness)
                 # campid: index of camera pair (1-5)
                 # pid: index of person in 'campid'-th camera pair
@@ -137,22 +115,22 @@ class CUHK03(BaseImageDataset):
                 img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
                 img_path = osp.join(save_dir, img_name)
                 if not osp.isfile(img_path):
-                    imsave(img_path, img)
+                    imwrite(img_path, img)
                 img_paths.append(img_path)
             return img_paths
 
-        def _extract_img(name):
-            print("Processing {} images (extract and save) ...".format(name))
+        def _extract_img(image_type):
+            print('Processing {} images ...'.format(image_type))
             meta_data = []
-            imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir
-            for campid, camp_ref in enumerate(mat[name][0]):
+            imgs_dir = self.imgs_detected_dir if image_type == 'detected' else self.imgs_labeled_dir
+            for campid, camp_ref in enumerate(mat[image_type][0]):
                 camp = _deref(camp_ref)
                 num_pids = camp.shape[0]
                 for pid in range(num_pids):
                     img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
-                    assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid)
+                    assert len(img_paths) > 0, 'campid{}-pid{} has no images'.format(campid, pid)
                     meta_data.append((campid + 1, pid + 1, img_paths))
-                print("- done camera pair {} with {} identities".format(campid + 1, num_pids))
+                print('- done camera pair {} with {} identities'.format(campid + 1, num_pids))
             return meta_data
 
         meta_detected = _extract_img('detected')
@@ -178,7 +156,7 @@ class CUHK03(BaseImageDataset):
                 num_train_imgs += len(img_paths)
             return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs
 
-        print("Creating classic splits (# = 20) ...")
+        print('Creating classic splits (# = 20) ...')
         splits_classic_det, splits_classic_lab = [], []
         for split_ref in mat['testsets'][0]:
             test_split = _deref(split_ref).tolist()
@@ -187,20 +165,30 @@ class CUHK03(BaseImageDataset):
             train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                 _extract_classic_split(meta_detected, test_split)
             splits_classic_det.append({
-                'train': train, 'query': test, 'gallery': test,
-                'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
-                'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
-                'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
+                'train': train,
+                'query': test,
+                'gallery': test,
+                'num_train_pids': num_train_pids,
+                'num_train_imgs': num_train_imgs,
+                'num_query_pids': num_test_pids,
+                'num_query_imgs': num_test_imgs,
+                'num_gallery_pids': num_test_pids,
+                'num_gallery_imgs': num_test_imgs
             })
 
             # create split for labeled images
             train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                 _extract_classic_split(meta_labeled, test_split)
             splits_classic_lab.append({
-                'train': train, 'query': test, 'gallery': test,
-                'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
-                'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
-                'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
+                'train': train,
+                'query': test,
+                'gallery': test,
+                'num_train_pids': num_train_pids,
+                'num_train_imgs': num_train_imgs,
+                'num_query_pids': num_test_pids,
+                'num_query_imgs': num_test_imgs,
+                'num_gallery_pids': num_test_pids,
+                'num_gallery_imgs': num_test_imgs
             })
 
         write_json(splits_classic_det, self.split_classic_det_json_path)
@@ -213,7 +201,8 @@ class CUHK03(BaseImageDataset):
                 img_name = filelist[idx][0]
                 camid = int(img_name.split('_')[2]) - 1  # make it 0-based
                 pid = pids[idx]
-                if relabel: pid = pid2label[pid]
+                if relabel:
+                    pid = pid2label[pid]
                 img_path = osp.join(img_dir, img_name)
                 tmp_set.append((img_path, int(pid), camid))
                 unique_pids.add(pid)
@@ -232,28 +221,39 @@ class CUHK03(BaseImageDataset):
             gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
             return train_info, query_info, gallery_info
 
-        print("Creating new splits for detected images (767/700) ...")
+        print('Creating new split for detected images (767/700) ...')
         train_info, query_info, gallery_info = _extract_new_split(
             loadmat(self.split_new_det_mat_path),
-            self.imgs_detected_dir,
+            self.imgs_detected_dir
         )
-        splits = [{
-            'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
-            'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
-            'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
-            'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
+        split = [{
+            'train': train_info[0],
+            'query': query_info[0],
+            'gallery': gallery_info[0],
+            'num_train_pids': train_info[1],
+            'num_train_imgs': train_info[2],
+            'num_query_pids': query_info[1],
+            'num_query_imgs': query_info[2],
+            'num_gallery_pids': gallery_info[1],
+            'num_gallery_imgs': gallery_info[2]
         }]
-        write_json(splits, self.split_new_det_json_path)
+        write_json(split, self.split_new_det_json_path)
 
-        print("Creating new splits for labeled images (767/700) ...")
+        print('Creating new split for labeled images (767/700) ...')
        train_info, query_info, gallery_info = _extract_new_split(
             loadmat(self.split_new_lab_mat_path),
-            self.imgs_labeled_dir,
+            self.imgs_labeled_dir
         )
-        splits = [{
-            'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
-            'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
-            'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
-            'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
+        split = [{
+            'train': train_info[0],
+            'query': query_info[0],
+            'gallery': gallery_info[0],
+            'num_train_pids': train_info[1],
+            'num_train_imgs': train_info[2],
+            'num_query_pids': query_info[1],
+            'num_query_imgs': query_info[2],
+            'num_gallery_pids': gallery_info[1],
+            'num_gallery_imgs': gallery_info[2]
        }]
-        write_json(splits, self.split_new_lab_json_path)
+        write_json(split, self.split_new_lab_json_path)
@@ -33,16 +33,20 @@ class ImageDataset(Dataset):
     """Image Person ReID Dataset"""
 
     def __init__(self, img_items, transform=None, relabel=True):
-        self.img_items,self.tfms,self.relabel = img_items,transform,relabel
+        self.tfms,self.relabel = transform,relabel
 
         self.pid2label = None
         if self.relabel:
+            self.img_items = []
             pids = set()
-            for i, item in enumerate(self.img_items):
-                pid = self.get_pids(item[0])  # path
-                self.img_items[i][1] = pid  # replace pid
+            for i, item in enumerate(img_items):
+                pid = self.get_pids(item[0], item[1])  # path
+                self.img_items.append((item[0], pid, item[2]))  # replace pid
                 pids.add(pid)
             self.pids = pids
             self.pid2label = dict([(p, i) for i, p in enumerate(self.pids)])
+        else:
+            self.img_items = img_items
 
     @property
     def c(self):
@@ -59,14 +63,9 @@ class ImageDataset(Dataset):
         if self.relabel: pid = self.pid2label[pid]
         return img, pid, camid
 
-    def get_pids(self, file_path):
+    def get_pids(self, file_path, pid):
         """ Suitable for muilti-dataset training """
-        if 'cuhk03' in file_path:
-            prefix = 'cuhk'
-            pid = '_'.join(file_path.split('/')[-1].split('_')[0:2])
-        else:
-            prefix = file_path.split('/')[1]
-            pat = re.compile(r'([-\d]+)_c(\d)')
-            pid, _ = pat.search(file_path).groups()
-        return prefix + '_' + pid
+        if 'cuhk03' in file_path: prefix = 'cuhk'
+        else: prefix = file_path.split('/')[1]
+        return prefix + '_' + str(pid)
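The prefix keeps person IDs from different datasets disjoint when their image lists are concatenated for multi-dataset training. A standalone illustration of the same logic, with hypothetical paths:

```python
def get_pids(file_path, pid):  # same logic as above, shown standalone
    if 'cuhk03' in file_path: prefix = 'cuhk'
    else: prefix = file_path.split('/')[1]
    return prefix + '_' + str(pid)

get_pids('datasets/Market-1501-v15.09.15/bounding_box_train/0002_c1s1_000451_03.jpg', 2)
# -> 'Market-1501-v15.09.15_2'
get_pids('datasets/cuhk03/images_detected/1_001_1_01.png', '1_001')
# -> 'cuhk_1_001'
```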
@@ -5,87 +5,52 @@
 """
 
 import glob
-import re
-import urllib
-import zipfile
-
 import os.path as osp
+import re
 
-from utils.iotools import mkdir_if_missing
-from .bases import BaseImageDataset
+from .bases import ImageDataset
 
 
-class DukeMTMCreID(BaseImageDataset):
-    """
-    DukeMTMC-reID
+class DukeMTMCreID(ImageDataset):
+    """DukeMTMC-reID.
+
     Reference:
-    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
-    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
-    URL: https://github.com/layumi/DukeMTMC-reID_evaluation
+        - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
+        - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
+
+    URL: `<https://github.com/layumi/DukeMTMC-reID_evaluation>`_
+
     Dataset statistics:
-    # identities: 1404 (train + query)
-    # images:16522 (train) + 2228 (query) + 17661 (gallery)
-    # cameras: 8
+        - identities: 1404 (train + query).
+        - images:16522 (train) + 2228 (query) + 17661 (gallery).
+        - cameras: 8.
     """
-    dataset_dir = 'dukemtmc-reid'
+    dataset_dir = 'DukeMTMC-reID'
+    dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
 
-    def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs):
-        super(DukeMTMCreID, self).__init__()
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
-        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
-        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
-        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
+    def __init__(self, root='datasets', **kwargs):
+        # self.root = osp.abspath(osp.expanduser(root))
+        self.root = root
+        self.dataset_dir = osp.join(self.root, self.dataset_dir)
+        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
+        self.query_dir = osp.join(self.dataset_dir, 'query')
+        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
 
-        self._download_data()
-        self._check_before_run()
+        required_files = [
+            self.dataset_dir,
+            self.train_dir,
+            self.query_dir,
+            self.gallery_dir
+        ]
+        self.check_before_run(required_files)
 
-        train = self._process_dir(self.train_dir, relabel=True)
-        query = self._process_dir(self.query_dir, relabel=False)
-        gallery = self._process_dir(self.gallery_dir, relabel=False)
+        train = self.process_dir(self.train_dir, relabel=True)
+        query = self.process_dir(self.query_dir, relabel=False)
+        gallery = self.process_dir(self.gallery_dir, relabel=False)
 
-        if verbose:
-            print("=> DukeMTMC-reID loaded")
-            self.print_dataset_statistics(train, query, gallery)
+        super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs)
 
-        self.train = train
-        self.query = query
-        self.gallery = gallery
-
-        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
-        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
-        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
-
-    def _download_data(self):
-        if osp.exists(self.dataset_dir):
-            print("This dataset has been downloaded.")
-            return
-
-        print("Creating directory {}".format(self.dataset_dir))
-        mkdir_if_missing(self.dataset_dir)
-        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
-
-        print("Downloading DukeMTMC-reID dataset")
-        urllib.urlretrieve(self.dataset_url, fpath)
-
-        print("Extracting files")
-        zip_ref = zipfile.ZipFile(fpath, 'r')
-        zip_ref.extractall(self.dataset_dir)
-        zip_ref.close()
-
-    def _check_before_run(self):
-        """Check if all files are available before going deeper"""
-        if not osp.exists(self.dataset_dir):
-            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
-        if not osp.exists(self.train_dir):
-            raise RuntimeError("'{}' is not available".format(self.train_dir))
-        if not osp.exists(self.query_dir):
-            raise RuntimeError("'{}' is not available".format(self.query_dir))
-        if not osp.exists(self.gallery_dir):
-            raise RuntimeError("'{}' is not available".format(self.gallery_dir))
-
-    def _process_dir(self, dir_path, relabel=False):
+    def process_dir(self, dir_path, relabel=False):
         img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
         pattern = re.compile(r'([-\d]+)_c(\d)')
 
@@ -95,12 +60,12 @@ class DukeMTMCreID(BaseImageDataset):
             pid_container.add(pid)
         pid2label = {pid: label for label, pid in enumerate(pid_container)}
 
-        dataset = []
+        data = []
         for img_path in img_paths:
             pid, camid = map(int, pattern.search(img_path).groups())
             assert 1 <= camid <= 8
             camid -= 1  # index starts from 0
             if relabel: pid = pid2label[pid]
-            dataset.append((img_path, pid, camid))
+            data.append((img_path, pid, camid))
 
-        return dataset
+        return data
@@ -3,19 +3,18 @@
 @author: liaoxingyu
 @contact: sherlockliao01@gmail.com
 """
-import numpy as np
 import copy
-from collections import defaultdict
 import sys
+import warnings
+from collections import defaultdict
+
+import numpy as np
 
 try:
     from csrc.eval_cylib.eval_metrics_cy import evaluate_cy
     IS_CYTHON_AVAI = True
-    print("Using Cython evaluation code as the backend")
+    # print("Using Cython evaluation code as the backend")
 except ImportError:
     IS_CYTHON_AVAI = False
-    warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
+    # warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
 
 
 def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
@@ -0,0 +1,64 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+
+import os
+import sys
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.backends import cudnn
+from torch.utils.data import Dataset
+
+
+def eval_roc(distmat, q_pids, g_pids, q_cmaids, g_camids, t_start=0.1, t_end=0.9):
+    # sort cosine dist from large to small
+    indices = np.argsort(distmat, axis=1)[:, ::-1]
+    # query id and gallery id match
+    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
+
+    new_dist = []
+    new_matches = []
+    # Remove the same identity in the same camera.
+    num_q = distmat.shape[0]
+    for q_idx in range(num_q):
+        q_pid = q_pids[q_idx]
+        q_camid = q_cmaids[q_idx]
+
+        order = indices[q_idx]
+        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
+        keep = np.invert(remove)
+        new_matches.extend(matches[q_idx][keep].tolist())
+        new_dist.extend(distmat[q_idx][indices[q_idx]][keep].tolist())
+
+    fpr = []
+    tpr = []
+    fps = []
+    tps = []
+    thresholds = np.arange(t_start, t_end, 0.02)
+
+    # get number of positive and negative examples in the dataset
+    p = sum(new_matches)
+    n = len(new_matches) - p
+
+    # iteration through all thresholds and determine fraction of true positives
+    # and false positives found at this threshold
+    for t in thresholds:
+        fp = 0
+        tp = 0
+        for i in range(len(new_dist)):
+            if new_dist[i] > t:
+                if new_matches[i] == 1:
+                    tp += 1
+                else:
+                    fp += 1
+        fpr.append(fp / float(n))
+        tpr.append(tp / float(p))
+        fps.append(fp)
+        tps.append(tp)
+    return fpr, tpr, fps, tps, p, n, thresholds
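A usage sketch for `eval_roc` with toy inputs (the data below is purely illustrative). Note that `distmat` is treated here as a cosine similarity matrix sorted from large to small, so higher values mean closer matches:

```python
import numpy as np

# toy inputs: 4 queries, 6 gallery items, similarities in [0, 1]
rng = np.random.RandomState(0)
distmat = rng.rand(4, 6)
q_pids = np.array([0, 1, 2, 3]); g_pids = np.array([0, 1, 2, 3, 4, 5])
q_camids = np.zeros(4, dtype=int); g_camids = np.ones(6, dtype=int)

fpr, tpr, fps, tps, p, n, thresholds = eval_roc(distmat, q_pids, g_pids, q_camids, g_camids)
for t, f_, t_ in zip(thresholds, fpr, tpr):
    print('thresh {:.2f}: FPR {:.3f}, TPR {:.3f}'.format(t, f_, t_))
```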
@ -9,77 +9,88 @@ import re
|
|||
|
||||
import os.path as osp
|
||||
|
||||
from .bases import BaseImageDataset
|
||||
from .bases import ImageDataset
|
||||
import warnings
|
||||
|
||||
|
||||
class Market1501(BaseImageDataset):
|
||||
"""
|
||||
Market1501
|
||||
class Market1501(ImageDataset):
|
||||
"""Market1501.
|
||||
|
||||
Reference:
|
||||
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
|
||||
URL: http://www.liangzheng.org/Project/project_reid.html
|
||||
Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
|
||||
|
||||
URL: `<http://www.liangzheng.org/Project/project_reid.html>`_
|
||||
|
||||
Dataset statistics:
|
||||
# identities: 1501 (+1 for background)
|
||||
# images: 12936 (train) + 3368 (query) + 15913 (gallery)
|
||||
- identities: 1501 (+1 for background).
|
||||
- images: 12936 (train) + 3368 (query) + 15913 (gallery).
|
||||
"""
|
||||
dataset_dir = 'market1501'
|
||||
_junk_pids = [0, -1]
|
||||
dataset_dir = ''
|
||||
dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
|
||||
|
||||
def __init__(self, root='/export/home/lxy/DATA/reid', verbose=True, **kwargs):
|
||||
super(Market1501, self).__init__()
|
||||
self.dataset_dir = osp.join(root, self.dataset_dir)
|
||||
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
|
||||
self.query_dir = osp.join(self.dataset_dir, 'query')
|
||||
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
|
||||
def __init__(self, root='datasets', market1501_500k=False, **kwargs):
|
||||
# self.root = osp.abspath(osp.expanduser(root))
|
||||
self.root = root
|
||||
self.dataset_dir = osp.join(self.root, self.dataset_dir)
|
||||
|
||||
self._check_before_run()
|
||||
        # allow alternative directory structure
        self.data_dir = self.dataset_dir
        data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15')
        if osp.isdir(data_dir):
            self.data_dir = data_dir
        else:
            warnings.warn('The current data structure is deprecated. Please '
                          'put data folders such as "bounding_box_train" under '
                          '"Market-1501-v15.09.15".')

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)
        self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.data_dir, 'query')
        self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
        self.extra_gallery_dir = osp.join(self.data_dir, 'images')
        self.market1501_500k = market1501_500k

        if verbose:
            print("=> Market1501 loaded")
            self.print_dataset_statistics(train, query, gallery)
        required_files = [
            self.data_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir
        ]
        if self.market1501_500k:
            required_files.append(self.extra_gallery_dir)
        self.check_before_run(required_files)

        self.train = train
        self.query = query
        self.gallery = gallery
        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)
        if self.market1501_500k:
            gallery += self.process_dir(self.extra_gallery_dir, relabel=False)

        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
        super(Market1501, self).__init__(train, query, gallery, **kwargs)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
    def process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, pid, camid))
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data

        return dataset
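For reference, a quick sketch of how the compiled pattern decodes a Market-1501 filename; the sample name follows the dataset's standard naming and is illustrative:

```python
import re

pattern = re.compile(r'([-\d]+)_c(\d)')
pid, camid = map(int, pattern.search('0002_c1s1_000451_03.jpg').groups())
# pid == 2, camid == 1; pid -1 marks junk images and is skipped,
# and camid is shifted to start from 0 later in process_dir
```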
@ -0,0 +1,99 @@
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import sys
import os
import os.path as osp

from .bases import ImageDataset

##### Log #####
# 22.01.2019
# - add v2
# - v1 and v2 differ in dir names
# - note that faces in v2 are blurred
TRAIN_DIR_KEY = 'train_dir'
TEST_DIR_KEY = 'test_dir'
VERSION_DICT = {
    'MSMT17_V1': {
        TRAIN_DIR_KEY: 'train',
        TEST_DIR_KEY: 'test',
    },
    'MSMT17_V2': {
        TRAIN_DIR_KEY: 'mask_train_v2',
        TEST_DIR_KEY: 'mask_test_v2',
    }
}


class MSMT17(ImageDataset):
    """MSMT17.
    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
    URL: `<http://www.pkuvmc.com/publications/msmt17.html>`_

    Dataset statistics:
        - identities: 4101.
        - images: 32621 (train) + 11659 (query) + 82161 (gallery).
        - cameras: 15.
    """
    # dataset_dir = 'MSMT17_V2'
    dataset_url = None

    def __init__(self, root='datasets', **kwargs):
        # self.root = osp.abspath(osp.expanduser(root))
        self.root = root
        self.dataset_dir = self.root

        has_main_dir = False
        for main_dir in VERSION_DICT:
            if osp.exists(osp.join(self.dataset_dir, main_dir)):
                train_dir = VERSION_DICT[main_dir][TRAIN_DIR_KEY]
                test_dir = VERSION_DICT[main_dir][TEST_DIR_KEY]
                has_main_dir = True
                break
        assert has_main_dir, 'Dataset folder not found'

        self.train_dir = osp.join(self.dataset_dir, main_dir, train_dir)
        self.test_dir = osp.join(self.dataset_dir, main_dir, test_dir)
        self.list_train_path = osp.join(self.dataset_dir, main_dir, 'list_train.txt')
        self.list_val_path = osp.join(self.dataset_dir, main_dir, 'list_val.txt')
        self.list_query_path = osp.join(self.dataset_dir, main_dir, 'list_query.txt')
        self.list_gallery_path = osp.join(self.dataset_dir, main_dir, 'list_gallery.txt')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.test_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, self.list_train_path)
        val = self.process_dir(self.train_dir, self.list_val_path)
        query = self.process_dir(self.test_dir, self.list_query_path)
        gallery = self.process_dir(self.test_dir, self.list_gallery_path)

        # Note: to fairly compare with published methods on the conventional ReID setting,
        # do not add val images to the training set.
        if 'combineall' in kwargs and kwargs['combineall']:
            train += val

        super(MSMT17, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, list_path):
        with open(list_path, 'r') as txt:
            lines = txt.readlines()

        data = []

        for img_idx, img_info in enumerate(lines):
            img_path, pid = img_info.split(' ')
            pid = int(pid)  # no need to relabel
            camid = int(img_path.split('_')[2]) - 1  # index starts from 0
            img_path = osp.join(dir_path, img_path)
            data.append((img_path, pid, camid))

        return data
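A sketch of how one row of the list files is parsed; the exact relative path below is an assumption about the MSMT17 layout, shown only to make the `split('_')[2]` indexing concrete:

```python
img_info = '0000/0000_000_01_0303morning_0008_0.jpg 0'
img_path, pid = img_info.split(' ')      # relative path + identity label
camid = int(img_path.split('_')[2]) - 1  # third field '01' -> camera index 0
```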
@ -0,0 +1,63 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch


class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485*255, 0.456*255, 0.406*255]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229*255, 0.224*255, 0.225*255]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target, self.next_camid = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            self.next_camid = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        camid = self.next_camid
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target, camid
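A minimal sketch of how this prefetcher is driven (the same pattern appears in engine/inference.py and engine/trainer.py below); `train_loader` is an assumed DataLoader yielding `(imgs, pids, camids)` batches:

```python
prefetcher = data_prefetcher(train_loader)
batch = prefetcher.next()
while batch[0] is not None:
    imgs, pids, camids = batch   # imgs are already on GPU, float, normalized
    # ... forward / backward ...
    batch = prefetcher.next()    # next host-to-device copy overlaps with compute
```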
@ -4,12 +4,13 @@
@contact: liaoxingyu2@jd.com
"""

import copy
import random
import re
from collections import defaultdict

import random
import copy
import numpy as np
import re
import torch
from torch.utils.data.sampler import Sampler


@ -22,9 +23,7 @@ class RandomIdentitySampler(Sampler):
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): number of examples in a batch.
    """

    def __init__(self, data_source, batch_size, num_instances):
        pat = re.compile(r'([-\d]+)_c(\d)')

        self.data_source = data_source
        self.batch_size = batch_size

@ -32,14 +31,7 @@ class RandomIdentitySampler(Sampler):
        self.num_pids_per_batch = self.batch_size // self.num_instances
        self.index_dic = defaultdict(list)
        for index, info in enumerate(self.data_source):
            fname = info[0]
            prefix = fname.split('/')[1]
            try:
                pid, _ = pat.search(fname).groups()
            except:  # cuhk03
                prefix = fname.split('/')[4]
                pid = '_'.join(fname.split('/')[-1].split('_')[:2])
            pid = prefix + '_' + pid
            pid = info[1]
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())


@ -78,14 +70,37 @@ class RandomIdentitySampler(Sampler):
            if len(batch_idxs_dict[pid]) == 0:
                avai_pids.remove(pid)

        if len(final_idxs) > self.length:
            final_idxs = final_idxs[:self.length]
        elif len(final_idxs) < self.length:
            cycle = self.length - len(final_idxs)
            final_idxs = final_idxs + final_idxs[:cycle]
        assert len(final_idxs) == self.length, 'sampler length must match'
        # if len(final_idxs) > self.length:
        #     final_idxs = final_idxs[:self.length]
        # elif len(final_idxs) < self.length:
        #     cycle = self.length - len(final_idxs)
        #     final_idxs = final_idxs + final_idxs[:cycle]
        # assert len(final_idxs) == self.length, 'sampler length must match'
        return iter(final_idxs)

    def __len__(self):
        return self.length

# class RandomIdentitySampler(Sampler):
#     def __init__(self, data_source, num_instances=4):
#         self.data_source = data_source
#         self.num_instances = num_instances
#         self.index_dic = defaultdict(list)
#         for index, (_, pid) in enumerate(data_source):
#             self.index_dic[pid].append(index)
#         self.pids = list(self.index_dic.keys())
#         self.num_identities = len(self.pids)
#
#     def __iter__(self):
#         indices = torch.randperm(self.num_identities)
#         ret = []
#         for i in indices:
#             pid = self.pids[i]
#             t = self.index_dic[pid]
#             replace = False if len(t) >= self.num_instances else True
#             t = np.random.choice(t, size=self.num_instances, replace=replace)
#             ret.extend(t)
#         return iter(ret)
#
#     def __len__(self):
#         return self.num_identities * self.num_instances
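A usage sketch, assuming `train_data` is the list of `(img_path, pid, camid)` records and `train_set` the Dataset wrapping it:

```python
from torch.utils.data import DataLoader

sampler = RandomIdentitySampler(train_data, batch_size=64, num_instances=4)
loader = DataLoader(train_set, batch_size=64, sampler=sampler, num_workers=8)
# each batch holds 64 // 4 = 16 identities with 4 images apiece,
# the layout the batch-hard triplet loss expects
```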
@ -10,20 +10,22 @@ from .transforms import *


def build_transforms(cfg, is_train=True):
    norm2tensor = [T.ToTensor(), T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)]
    res = []
    if is_train:
        res.append(T.Resize(cfg.INPUT.SIZE_TRAIN))
        if cfg.INPUT.DO_FLIP: res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
        if cfg.INPUT.DO_PAD: res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
                                         T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
        if cfg.INPUT.DO_LIGHTING: res.append(T.ColorJitter(cfg.INPUT.MAX_LIGHTING, cfg.INPUT.MAX_LIGHTING))
        res.extend(norm2tensor)
        if cfg.INPUT.DO_RE: res.append(RandomErasing(probability=cfg.INPUT.RE_PROB,
                                                     mean=cfg.INPUT.PIXEL_MEAN))
        if cfg.INPUT.DO_FLIP:
            res.append(T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB))
        if cfg.INPUT.DO_PAD:
            res.extend([T.Pad(cfg.INPUT.PADDING, padding_mode=cfg.INPUT.PADDING_MODE),
                        T.RandomCrop(cfg.INPUT.SIZE_TRAIN)])
        if cfg.INPUT.DO_LIGHTING:
            res.append(T.ColorJitter(cfg.INPUT.MAX_LIGHTING, cfg.INPUT.MAX_LIGHTING))
        # res.append(T.ToTensor())  # too slow
        if cfg.INPUT.DO_RE:
            res.append(RandomErasing(probability=cfg.INPUT.RE_PROB))
    else:
        res.append(T.Resize(cfg.INPUT.SIZE_TEST))
        res.extend(norm2tensor)
        # res.append(T.ToTensor())
    return T.Compose(res)
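Note that ToTensor/Normalize are gone from both branches: float conversion and mean/std normalization now happen on the GPU inside `data_prefetcher`, which is why `RandomErasing` above no longer takes the 0-1-scaled `PIXEL_MEAN`. A sketch of the split, with collation details left as an assumption:

```python
train_tfms = build_transforms(cfg, is_train=True)   # PIL-level augmentation only
test_tfms = build_transforms(cfg, is_train=False)   # just a resize
# downstream, data_prefetcher does: img.float().sub_(mean * 255).div_(std * 255)
```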
@ -0,0 +1,71 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import random
from PIL import Image

__all__ = ['swap']


def swap(img, crop):
    def crop_image(image, cropnum):
        width, high = image.size
        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
        im_list = []
        for j in range(len(crop_y) - 1):
            for i in range(len(crop_x) - 1):
                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
        return im_list

    widthcut, highcut = img.size
    img = img.crop((10, 10, widthcut - 10, highcut - 10))
    images = crop_image(img, crop)
    pro = 5
    if pro >= 5:
        tmpx = []
        tmpy = []
        count_x = 0
        count_y = 0
        k = 1
        RAN = 2
        for i in range(crop[1] * crop[0]):
            tmpx.append(images[i])
            count_x += 1
            if len(tmpx) >= k:
                tmp = tmpx[count_x - RAN:count_x]
                random.shuffle(tmp)
                tmpx[count_x - RAN:count_x] = tmp
            if count_x == crop[0]:
                tmpy.append(tmpx)
                count_x = 0
                count_y += 1
                tmpx = []
            if len(tmpy) >= k:
                tmp2 = tmpy[count_y - RAN:count_y]
                random.shuffle(tmp2)
                tmpy[count_y - RAN:count_y] = tmp2
        random_im = []
        for line in tmpy:
            random_im.extend(line)

        # random.shuffle(images)
        width, high = img.size
        iw = int(width / crop[0])
        ih = int(high / crop[1])
        toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
        x = 0
        y = 0
        for i in random_im:
            i = i.resize((iw, ih), Image.ANTIALIAS)
            toImage.paste(i, (x * iw, y * ih))
            x += 1
            if x == crop[0]:
                x = 0
                y += 1
    else:
        toImage = img
    toImage = toImage.resize((widthcut, highcut))
    return toImage
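A minimal sketch of calling the patch-shuffling transform above:

```python
from PIL import Image

img = Image.new('RGB', (128, 256), 'gray')  # placeholder image
swapped = swap(img, (2, 2))                 # 2x2 grid, neighbouring patches shuffled
print(swapped.size)                         # (128, 256): resized back to the input size
```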
@ -4,10 +4,14 @@
@contact: sherlockliao01@gmail.com
"""

__all__ = ['RandomErasing', 'Randomswap']

import math
import numbers
import random

__all__ = ['RandomErasing', ]
import numpy as np

from .functional import *


class RandomErasing(object):

@ -22,7 +26,7 @@ class RandomErasing(object):
        mean: Erasing value.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(255*0.59606, 255*0.55814, 255*0.49735)):
        self.probability = probability
        self.mean = mean
        self.sl = sl

@ -30,12 +34,12 @@ class RandomErasing(object):
        self.r1 = r1

    def __call__(self, img):
        img = np.asarray(img, dtype=np.uint8).copy()
        if random.uniform(0, 1) > self.probability:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]
            area = img.shape[0] * img.shape[1]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

@ -43,16 +47,31 @@ class RandomErasing(object):
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
            if w < img.shape[1] and h < img.shape[0]:
                x1 = random.randint(0, img.shape[0] - h)
                y1 = random.randint(0, img.shape[1] - w)
                if img.shape[2] == 3:
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
                return img

        return img


class Randomswap(object):
    def __init__(self, size):
        self.size = size
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return swap(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
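The rewrite makes `RandomErasing` consume a PIL image and return an HWC uint8 numpy array (it now runs before any tensor conversion), which is why the fill `mean` moved to the 0-255 scale. A quick check:

```python
from PIL import Image

re_aug = RandomErasing(probability=1.0)      # always erase, for illustration
out = re_aug(Image.new('RGB', (128, 256)))   # returns an HWC uint8 ndarray
print(out.shape)                             # (256, 128, 3)
```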
@ -9,52 +9,54 @@ import torch
import numpy as np
import torch.nn.functional as F
from data.datasets.eval_reid import evaluate
from fastai.torch_core import to_np
from data.prefetcher import data_prefetcher


def inference(
        cfg,
        model,
        data_bunch,
        tst_loader,
        test_dataloader,
        num_query
):
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inference")

    model.eval()
    feats = []
    pids = []
    camids = []
    for imgs, pid, camid in data_bunch.test_dl:

    feats, pids, camids = [], [], []
    test_prefetcher = data_prefetcher(test_dataloader)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            feat = model(imgs.cuda())
            feat = model(img)
        feats.append(feat)
        pids.append(pid)
        camids.append(camid)
        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    feats = torch.cat(feats, dim=0)
    if cfg.TEST.NORM:
        feats = F.normalize(feats, p=2, dim=1)
    # query
    qf = feats[:num_query]
    gf = feats[num_query:]
    q_pids = np.asarray(pids[:num_query])
    g_pids = np.asarray(pids[num_query:])
    q_camids = np.asarray(camids[:num_query])
    # gallery
    gf = feats[num_query:]
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    m, n = qf.shape[0], gf.shape[0]
    # Cosine distance
    distmat = torch.mm(F.normalize(qf), F.normalize(gf).t())
    # cosine distance
    distmat = torch.mm(qf, gf.t()).cpu().numpy()

    # Euclidean distance
    # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
    #     torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # distmat.addmm_(1, -2, qf, gf.t())

    distmat = to_np(distmat)

    # Compute CMC and mAP.
    # distmat = distmat.numpy()
    cmc, mAP = evaluate(-distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info('Compute CMC Curve')
    logger.info("mAP: {:.1%}".format(mAP))
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
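Why `evaluate` receives `-distmat`: the matrix now holds cosine similarities (larger means closer), while the CMC/mAP code ranks by distance (smaller means closer). Negating converts one convention into the other:

```python
import numpy as np

sim = np.array([[0.9, 0.1],
                [0.2, 0.8]])      # query x gallery cosine similarities
order = np.argsort(-sim, axis=1)  # same ranking as sorting the distance -sim
print(order)                      # [[0 1] [1 0]]: best match first
```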
@ -4,17 +4,25 @@
@contact: sherlockliao01@gmail.com
"""

from fastai.callbacks import *
from fastai.vision import *
import torch
import numpy as np
import torch.nn.functional as F
from collections import namedtuple
import matplotlib.pyplot as plt
from modeling import build_model
from data import get_dataloader
from data.prefetcher import data_prefetcher


class ReidInterpretation():
    "Interpretation methods for reid models."

    def __init__(self, model, tst_loader, num_q):
        self.model,self.tst_loader,self.num_q = model,tst_loader,num_q
    def __init__(self, cfg):
        self.cfg = cfg
        self.model = build_model(cfg, 0)
        self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(cfg)
        self.model = self.model.cuda()

        self.model.load_params_wo_fc(torch.load(cfg.TEST.WEIGHT))

        self.get_distmat()

    def get_distmat(self):

@ -22,26 +30,28 @@ class ReidInterpretation():
        feats = []
        pids = []
        camids = []
        for img, pid, camid in self.tst_loader:
        val_prefetcher = data_prefetcher(self.val_dataloader)
        batch = val_prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = m(img.cuda())
            feats.append(feat)
            pids.extend(to_np(pid))
            camids.extend(to_np(camid))
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))
        feats = torch.cat(feats, dim=0)
        feats = F.normalize(feats)
        qf = feats[:self.num_q]
        gf = feats[self.num_q:]
        self.q_pids = np.asarray(pids[:self.num_q])
        self.g_pids = np.asarray(pids[self.num_q:])
        self.q_camids = np.asarray(camids[:self.num_q])
        self.g_camids = np.asarray(camids[self.num_q:])

        m, n = qf.shape[0], gf.shape[0]
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats)
        qf = feats[:self.num_query]
        gf = feats[self.num_query:]
        self.q_pids = np.asarray(pids[:self.num_query])
        self.g_pids = np.asarray(pids[self.num_query:])
        self.q_camids = np.asarray(camids[:self.num_query])
        self.g_camids = np.asarray(camids[self.num_query:])

        # Cosine distance
        distmat = torch.mm(qf, gf.t())
        self.distmat = to_np(distmat)
        self.distmat = distmat.cpu().numpy()
        self.indices = np.argsort(-self.distmat, axis=1)
        self.matches = (self.g_pids[self.indices] == self.q_pids[:, np.newaxis]).astype(np.int32)


@ -56,61 +66,13 @@ class ReidInterpretation():
        sort_idx = order[keep]
        return cmc, sort_idx

    # def plot_gradcam(self, q_idx):
    #     m = self.model.eval()
    #     cmc, sort_idx = self.get_matched_result(q_idx)
    #     fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    #     fig.suptitle('query  gallery gradcam')
    #     query_im, _, _ = self.tst_loader.dataset[q_idx]
    #     de_query_im = Image(denormalize(query_im, tensor(imagenet_stats[0]), tensor(imagenet_stats[1])))
    #     de_query_im.show(ax=axes.flat[0], title='query')
    #
    #     g_idx = self.num_q + sort_idx[0]
    #     im, _, _ = self.tst_loader.dataset[g_idx]
    #     de_im = Image(denormalize(im, tensor(imagenet_stats[0]), tensor(imagenet_stats[1])))
    #     if cmc[0] == 1: label = 'true'
    #     else: label = 'false'
    #     de_im.show(ax=axes.flat[1], title='gallery')
    #
    #     query_im = query_im[None,...]
    #     qf_cont = F.normalize(m(query_im.cuda()).detach(), p=2, dim=1)
    #     im = im[None,...]
    #     gf_cont = F.normalize(m(im.cuda()).detach(), p=2, dim=1)
    #
    #     with hook_output(m.base) as hook_a:
    #         with hook_output(m.base) as hook_g:
    #             qf = m(query_im.cuda())
    #             sim = torch.mm(F.normalize(qf, p=2, dim=1), gf_cont.t())
    #             sim.backward()
    #     acts = hook_a.stored[0].cpu()  # activation maps
    #     grad = hook_g.stored[0].cpu()
    #     grad_chan = grad.mean(1).mean(1)
    #     mult = F.relu(((acts * grad_chan[...,None,None])).sum(0))
    #
    #     acts = self.get_actmap(acts)
    #     sz = list(query_im.shape[-2:])
    #     axes.flat[0].imshow(mult, alpha=0.4, extent=(0, *sz[::-1], 0), interpolation='bilinear', cmap='jet')
    #
    #     with hook_output(m.base) as hook_a:
    #         with hook_output(m.base) as hook_g:
    #             gf = m(im.cuda())
    #             sim = torch.mm(F.normalize(gf, p=2, dim=1), qf_cont.t())
    #             sim.backward()
    #     acts = hook_a.stored[0].cpu()  # activation maps
    #     grad = hook_g.stored[0].cpu()
    #     grad_chan = grad.mean(1).mean(1)
    #     mult = F.relu(((acts * grad_chan[...,None,None])).sum(0))
    #
    #     acts = self.get_actmap(acts)
    #     sz = list(im.shape[-2:])
    #     axes.flat[1].imshow(mult, alpha=0.4, extent=(0, *sz[::-1], 0), interpolation='bilinear', cmap='jet')

    def plot_rank_result(self, q_idx, top=5, actmap=False, title="Rank result"):
        m = self.model.eval()
        cmc, sort_idx = self.get_matched_result(q_idx)
        fig, axes = plt.subplots(1, top+1, figsize=(15, 5))
        fig.suptitle('query  similarity/true(false)')
        query_im, _, _ = self.tst_loader.dataset[q_idx]
        query_im, _, _ = self.val_dataloader.dataset[q_idx]
        from ipdb import set_trace; set_trace()
        de_query_im = Image(denormalize(query_im, tensor(imagenet_stats[0]), tensor(imagenet_stats[1])))
        de_query_im.show(ax=axes.flat[0], title='query')
        if actmap:
@ -4,47 +4,136 @@
@contact: sherlockliao01@gmail.com
"""

import logging
import os
import shutil
import time

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from test_tube import Experiment
from torch import nn

from data import get_dataloader
from data.datasets.eval_reid import evaluate
from modeling import build_model, reidLoss
from solver.build import make_optimizer, make_lr_scheduler
from data.prefetcher import data_prefetcher
from modeling import build_model
from modeling.losses import TripletLoss
from solver.build import make_lr_scheduler, make_optimizer
from utils.meters import AverageMeter


class ReidSystem(pl.LightningModule):
    def __init__(self, cfg, logger, tng_loader, val_loader, num_classes, num_query):
        super().__init__()
        # Define networks
        self.cfg,self.logger,self.tng_loader,self.val_loader,self.num_classes,self.num_query = \
            cfg,logger,tng_loader,val_loader,num_classes,num_query
        self.model = build_model(cfg, num_classes)
        self.loss_fns = reidLoss(cfg.SOLVER.LOSSTYPE, cfg.SOLVER.MARGIN, num_classes)
class ReidSystem():
    def __init__(self, cfg, logger, writer):
        self.cfg, self.logger, self.writer = cfg, logger, writer
        # Define dataloader
        self.tng_dataloader, self.val_dataloader, self.num_classes, self.num_query = get_dataloader(cfg)
        # networks
        self.model = build_model(cfg, self.num_classes)
        # loss function
        self.ce_loss = nn.CrossEntropyLoss()
        self.triplet = TripletLoss(cfg.SOLVER.MARGIN)
        # optimizer and scheduler
        self.opt = make_optimizer(self.cfg, self.model)
        self.lr_sched = make_lr_scheduler(self.cfg, self.opt)

    def training_step(self, batch, batch_nb):
        inputs, labels = batch
        outs = self.model(inputs, labels)
        loss = self.loss_fns(outs, labels)
        return {'loss': loss}
        self._construct()

    def validation_step(self, batch, batch_nb):
        inputs, pids, camids = batch
        feats = self.model(inputs)
        return {'feats': feats, 'pids': pids.cpu().numpy(), 'camids': camids.cpu().numpy()}
    def _construct(self):
        self.global_step = 0
        self.current_epoch = 0
        self.batch_nb = 0
        self.max_epochs = self.cfg.SOLVER.MAX_EPOCHS
        self.log_interval = self.cfg.SOLVER.LOG_INTERVAL
        self.eval_period = self.cfg.SOLVER.EVAL_PERIOD
        self.use_dp = False
        self.use_ddp = False

    def loss_fns(self, outputs, labels):
        ce_loss = self.ce_loss(outputs[0], labels)
        triplet_loss = self.triplet(outputs[1], labels)[0]

        return {'ce_loss': ce_loss, 'triplet': triplet_loss}

    def on_train_begin(self):
        self.best_mAP = -np.inf
        self.running_loss = AverageMeter()
        log_save_dir = os.path.join(self.cfg.OUTPUT_DIR, self.cfg.DATASETS.TEST_NAMES, self.cfg.MODEL.VERSION)
        self.model_save_dir = os.path.join(log_save_dir, 'ckpts')
        if not os.path.exists(self.model_save_dir): os.makedirs(self.model_save_dir)

        self.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        self.use_dp = (len(self.gpus) > 0) and (self.cfg.MODEL.DIST_BACKEND == 'dp')

        if self.use_dp:
            self.model = nn.DataParallel(self.model)

        self.model = self.model.cuda()

        self.model.train()

    def on_epoch_begin(self):
        self.batch_nb = 0
        self.current_epoch += 1
        self.t0 = time.time()
        self.running_loss.reset()

        self.tng_prefetcher = data_prefetcher(self.tng_dataloader)

    def training_step(self, batch):
        inputs, labels, _ = batch
        outputs = self.model(inputs, labels)
        loss_dict = self.loss_fns(outputs, labels)

        total_loss = 0
        print_str = f'\r Epoch {self.current_epoch} Iter {self.batch_nb}/{len(self.tng_dataloader)} '
        for loss_name, loss_value in loss_dict.items():
            total_loss += loss_value
            print_str += (loss_name+f': {loss_value.item():.3f} ')
        loss_dict['total_loss'] = total_loss.item()
        print_str += f'Total loss: {total_loss.item():.3f} '
        print(print_str, end=' ')

        if (self.global_step+1) % self.log_interval == 0:
            self.writer.add_scalar('cross_entropy_loss', loss_dict['ce_loss'], self.global_step)
            self.writer.add_scalar('triplet_loss', loss_dict['triplet'], self.global_step)
            self.writer.add_scalar('total_loss', loss_dict['total_loss'], self.global_step)

        self.running_loss.update(total_loss.item())

        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

        self.global_step += 1
        self.batch_nb += 1

    def on_epoch_end(self):
        elapsed = time.time() - self.t0
        mins = int(elapsed) // 60
        seconds = int(elapsed - mins * 60)
        print('')
        self.logger.info(f'Epoch {self.current_epoch} Total loss: {self.running_loss.avg:.3f} '
                         f'lr: {self.opt.param_groups[0]["lr"]:.2e} During {mins:d}min:{seconds:d}s')
        # update learning rate
        self.lr_sched.step()

    def test(self):
        # convert to eval mode
        self.model.eval()

    def validation_end(self, outputs):
        feats,pids,camids = [],[],[]
        for o in outputs:
            feats.append(o['feats'])
            pids.extend(o['pids'])
            camids.extend(o['camids'])
        val_prefetcher = data_prefetcher(self.val_dataloader)
        batch = val_prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                feat = self.model(img)
            feats.append(feat)
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))

            batch = val_prefetcher.next()

        feats = torch.cat(feats, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)

@ -57,83 +146,54 @@ class ReidSystem(pl.LightningModule):
        g_pids = np.asarray(pids[self.num_query:])
        g_camids = np.asarray(camids[self.num_query:])

        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
            torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        distmat.addmm_(1, -2, qf, gf.t())
        distmat = distmat.cpu().numpy()
        cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch + 1}")
        # m, n = qf.shape[0], gf.shape[0]
        distmat = torch.mm(qf, gf.t()).cpu().numpy()
        # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        #     torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        # distmat.addmm_(1, -2, qf, gf.t())
        # distmat = distmat.numpy()
        cmc, mAP = evaluate(-distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
        tqdm_dic = {'rank1': cmc[0], 'mAP': mAP}
        return tqdm_dic

        self.writer.add_scalar('rank1', cmc[0], self.global_step)
        self.writer.add_scalar('mAP', mAP, self.global_step)
        metric_dict = {'rank1': cmc[0], 'mAP': mAP}
        # convert to train mode
        self.model.train()
        return metric_dict

    def configure_optimizers(self):
        opt_fns = make_optimizer(self.cfg, self.model)
        lr_sched = make_lr_scheduler(self.cfg, opt_fns)
        return [opt_fns], [lr_sched]
    def train(self):
        self.on_train_begin()
        for epoch in range(self.max_epochs):
            self.on_epoch_begin()
            batch = self.tng_prefetcher.next()
            while batch[0] is not None:
                self.training_step(batch)
                batch = self.tng_prefetcher.next()
            self.on_epoch_end()
            if (epoch+1) % self.eval_period == 0:
                metric_dict = self.test()
                if metric_dict['mAP'] > self.best_mAP:
                    is_best = True
                    self.best_mAP = metric_dict['mAP']
                else:
                    is_best = False
                self.save_checkpoints(is_best)

    @pl.data_loader
    def tng_dataloader(self):
        return self.tng_loader

    @pl.data_loader
    def val_dataloader(self):
        return self.val_loader


def do_train(
        cfg,
        local_rank,
        tng_loader,
        val_loader,
        num_classes,
        num_query,
):
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    epochs = cfg.SOLVER.MAX_EPOCHS
    gpus = cfg.MODEL.GPUS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start Training")

    filepath = os.path.join(output_dir, cfg.DATASETS.TEST_NAMES, 'version_'+cfg.MODEL.VERSION, 'ckpts')
    checkpoint_callback = ModelCheckpoint(
        filepath=filepath,
        monitor='rank1',
        save_best_only=True,
        verbose=True,
        mode='max',
    )

    model = ReidSystem(cfg, logger, tng_loader, val_loader, num_classes, num_query)
    exp = Experiment(save_dir=output_dir, name=cfg.DATASETS.TEST_NAMES, version=cfg.MODEL.VERSION)

    trainer = pl.Trainer(
        experiment=exp,
        max_nb_epochs=epochs,
        checkpoint_callback=checkpoint_callback,
        check_val_every_n_epoch=eval_period,
        gpus=gpus,
        nb_sanity_val_steps=0,
        print_weights_summary=False,
        add_log_row_interval=len(tng_loader)//2,
    )

    trainer.fit(model)

    # continue training
    # if cfg.MODEL.CHECKPOINT is not '':
    #     state = torch.load(cfg.MODEL.CHECKPOINT)
    #     if set(state.keys()) == {'model', 'opt'}:
    #         model_state = state['model']
    #         learn.model.load_state_dict(model_state)
    #         learn.create_opt(0, 0)
    #         learn.opt.load_state_dict(state['opt'])
    #     else:
    #         learn.model.load_state_dict(state['model'])
    #     logger.info(f'continue training from checkpoint {cfg.MODEL.CHECKPOINT}')
    torch.cuda.empty_cache()

    def save_checkpoints(self, is_best):
        if self.use_dp:
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()

        # TODO: add optimizer state dict and lr scheduler
        filepath = os.path.join(self.model_save_dir, f'model_epoch{self.current_epoch}.pth')
        torch.save(state_dict, filepath)
        if is_best:
            best_filepath = os.path.join(self.model_save_dir, 'model_best.pth')
            shutil.copyfile(filepath, best_filepath)
@ -29,10 +29,13 @@ model_layers = {
    'resnet101': [3, 4, 23, 3]
}

__all__ = ['ResNet']
__all__ = ['ResNet', 'Bottleneck']


class IBN(nn.Module):
    """
    IBN with BN:IN = 7:1
    """
    def __init__(self, planes):
        super(IBN, self).__init__()
        half1 = int(planes/8)

@ -151,7 +154,7 @@ class ResNet(nn.Module):
        return x

    def load_pretrain(self, model_path=''):
        with_model_path = model_path is not ''
        with_model_path = (model_path is not '')
        if not with_model_path:  # resnet pretrain
            state_dict = model_zoo.load_url(model_urls[self._model_name])
            state_dict.pop('fc.weight')
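The IBN fragment above only shows the split point: `planes/8` channels go to instance norm and the rest to batch norm, matching the 7:1 docstring. A sketch of the forward pass this implies; layer names here are assumptions, not the file's actual code:

```python
import torch
from torch import nn

class IBNSketch(nn.Module):
    def __init__(self, planes):
        super().__init__()
        self.half = int(planes / 8)                      # IN share of the channels
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)     # BN share (7x larger)

    def forward(self, x):
        part_in, part_bn = torch.split(x, [self.half, x.size(1) - self.half], dim=1)
        return torch.cat((self.IN(part_in), self.BN(part_bn)), dim=1)
```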
@ -8,29 +8,7 @@ from torch import nn

from .backbones import *
from .losses.cosface import AddMarginProduct


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias:
            nn.init.constant_(m.bias, 0.0)
from .utils import *


class Baseline(nn.Module):

@ -46,35 +24,42 @@ class Baseline(nn.Module):
                 pretrain=True,
                 model_path=''):
        super().__init__()
        try: self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
        except: print(f'unsupported backbone: {backbone}')
        try:
            self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
        except:
            print(f'unsupported backbone: {backbone}')

        if pretrain: self.base.load_pretrain(model_path)
        if pretrain:
            self.base.load_pretrain(model_path)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.num_classes = num_classes

        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)  # no shift
        # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        self.classifier = AddMarginProduct(self.in_planes, self.num_classes, s=30, m=0.3)
        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        # self.classifier = AddMarginProduct(self.in_planes, self.num_classes, s=20, m=0.3)

        self.bottleneck.apply(weights_init_kaiming)
        # self.classifier.apply(weights_init_classifier)
        self.classifier.apply(weights_init_classifier)

    def forward(self, x, label=None):
        global_feat = self.gap(self.base(x))  # (b, 2048, 1, 1)
        global_feat = global_feat.view(-1, global_feat.size()[1])
        feat = self.bottleneck(global_feat)  # normalize for angular softmax
        if self.training:
            cls_score = self.classifier(feat, label)  # (2*b, class)
            # adv_score = self.classifier_swap(feat)  # (2*b, 2)
            # return cls_score, adv_score, global_feat  # global feature for triplet loss
            cls_score = self.classifier(feat)
            # cls_score = self.classifier(feat, label)
            return cls_score, global_feat
        else:
            return feat

    def load_params_wo_fc(self, state_dict):
        # new_state_dict = {}
        # for k, v in state_dict.items():
        #     k = '.'.join(k.split('.')[1:])
        #     new_state_dict[k] = v
        # state_dict = new_state_dict
        state_dict.pop('classifier.weight')
        res = self.load_state_dict(state_dict, strict=False)
        assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
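How the two forward paths are consumed elsewhere in this commit (`cfg`, `imgs`, and `labels` stand in for real objects; `build_model` is the factory used by tools/train.py and tools/test.py):

```python
model = build_model(cfg, num_classes)
model.train()
cls_score, global_feat = model(imgs, labels)  # softmax head + triplet feature
model.eval()
with torch.no_grad():
    feat = model(imgs)                        # BN-ed embedding used for retrieval
```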
@ -4,4 +4,5 @@
@contact: sherlockliao01@gmail.com
"""

from .triplet_loss import TripletLoss
from .loss import *
@ -0,0 +1,46 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn


class CenterLoss(nn.Module):
    """Center loss.
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes, self.feat_dim = num_classes, feat_dim

        if use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"

        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        classes = classes.to(x.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss
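A small usage sketch (dimensions are illustrative); the loss is the mean squared distance between each feature and the learned center of its own class:

```python
import torch

center_loss = CenterLoss(num_classes=751, feat_dim=2048, use_gpu=False)
feats = torch.randn(16, 2048)           # batch of embeddings
labels = torch.randint(0, 751, (16,))   # ground-truth identities
loss = center_loss(feats, labels)
```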
@ -5,15 +5,13 @@
"""
from torch import nn

from .center_loss import CenterLoss
from .cosface import AddMarginProduct
from .label_smooth import CrossEntropyLabelSmooth
from .triplet_loss import TripletLoss

__all__ = ['reidLoss']


class reidLoss(nn.Module):
class reidLoss(object):
    def __init__(self, lossType: list, margin: float, num_classes: int):
        super().__init__()
        self.lossType = lossType

@ -23,10 +21,24 @@ class reidLoss(nn.Module):
        if 'triplet' in self.lossType: self.triplet_loss = TripletLoss(margin)
        # if 'center' in self.lossType: self.center_loss = CenterLoss(num_classes, feat_dim)

    def forward(self, out, labels):
        cls_scores, feats = out
        loss = 0
        if 'softmax' in self.lossType or 'softmax_smooth' in self.lossType: loss += self.ce_loss(cls_scores, labels)
        if 'triplet' in self.lossType: loss += self.triplet_loss(feats, labels)[0]
    def __call__(self, outputs, labels):
        # cls_scores, feats = outputs
        loss = {}
        if 'softmax' in self.lossType or 'softmax_smooth' in self.lossType:
            loss['ce_loss'] = self.ce_loss(outputs[0], labels)
            # loss['ce_loss'] = 0
            # ce_iter = 0
            # for output in outputs[1:]:
            #     loss['ce_loss'] += self.ce_loss(output, labels)
            #     ce_iter += 1
            # loss['ce_loss'] = 2 * loss['ce_loss'] / ce_iter
        if 'triplet' in self.lossType:
            loss['triplet'] = self.triplet_loss(outputs[1], labels)[0]
            # tri_iter = 0
            # for output in outputs[:3]:
            #     loss['triplet'] += self.triplet_loss(output, labels)[0]
            #     tri_iter += 1
            # loss['triplet'] = loss['triplet'] / tri_iter
            # loss['triplet'] = self.triplet_loss(feats, labels)[0]
        # if 'center' in self.lossType: loss += 0.0005 * self.center_loss(feats, labels)
        return loss
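`reidLoss` now returns a dict of named losses and leaves the summation to the caller, mirroring `loss_fns` in engine/trainer.py; a sketch with assumed inputs:

```python
loss_fn = reidLoss(lossType=['softmax', 'triplet'], margin=0.3, num_classes=751)
loss_dict = loss_fn((cls_score, global_feat), labels)
total_loss = sum(loss_dict.values())   # loss_dict['ce_loss'] + loss_dict['triplet']
```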
@ -115,10 +115,11 @@ class TripletLoss(nn.Module):
    def forward(self, global_feat, labels, normalize_feature=False):
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)

        dist_mat = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(
            dist_mat, labels)
        dist_ap, dist_an = hard_example_mining(dist_mat, labels)
        y = dist_an.new().resize_as_(dist_an).fill_(1)

        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
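With `y = 1`, `nn.MarginRankingLoss(margin)` applied as above reduces to the familiar batch-hard triplet objective; an equivalent hand-written form:

```python
loss = torch.clamp(margin + dist_ap - dist_an, min=0).mean()
# pushes the hardest negative at least `margin` farther away than the hardest positive
```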
@ -0,0 +1,37 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn

__all__ = ['weights_init_classifier', 'weights_init_kaiming', 'BN_no_bias']


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias:
            nn.init.constant_(m.bias, 0.0)


def BN_no_bias(in_features):
    bn_layer = nn.BatchNorm1d(in_features)
    bn_layer.bias.requires_grad_(False)
    return bn_layer
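Typical usage, as in modeling/baseline.py: pass the initializers to `nn.Module.apply` so they visit every submodule (the head below is an illustrative stand-in):

```python
from torch import nn

head = nn.Sequential(BN_no_bias(2048), nn.Linear(2048, 751, bias=False))
head[0].apply(weights_init_kaiming)      # BatchNorm: weight=1, bias=0 (bias frozen)
head[1].apply(weights_init_classifier)   # classifier: small-std normal init
```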
@ -0,0 +1,6 @@
gpu=2

CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501",)' \
DATASETS.TEST_NAMES 'market1501' \
OUTPUT_DIR 'logs/test'

@ -1,7 +1,8 @@
gpu=0
GPUS='2'

CUDA_VISIBLE_DEVICES=$gpu python tools/test.py -cfg='configs/softmax_triplet.yml' \
CUDA_VISIBLE_DEVICES=$GPUS python tools/test.py -cfg='configs/softmax_triplet.yml' \
DATASETS.TEST_NAMES 'beijing' \
MODEL.NAME 'baseline' \
MODEL.BACKBONE 'resnet50' \
DATASETS.TEST_NAMES 'duke' \
OUTPUT_DIR 'logs/test' \
TEST.WEIGHT 'logs/2019.8.16/market/resnet50/models/model_149.pth'
MODEL.WITH_IBN 'False' \
TEST.WEIGHT 'logs/beijing/combineall_bs256_cosface/ckpts/model_best.pth'

@ -0,0 +1,29 @@
gpu=2

#CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
#DATASETS.NAMES '("market1501","duke","cuhk03","beijing")' \
#DATASETS.TEST_NAMES 'bj' \
#MODEL.BACKBONE 'resnet50' \
#MODEL.WITH_IBN 'False' \
#MODEL.STAGE_WITH_GCB '(False, False, False, False)' \
#SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \
#OUTPUT_DIR 'logs/2019.8.26/bj/softmax_smooth'


CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501","duke","cuhk03","beijing")' \
DATASETS.TEST_NAMES 'bj' \
INPUT.DO_LIGHTING 'False' \
MODEL.BACKBONE 'resnet50' \
MODEL.WITH_IBN 'True' \
MODEL.PRETRAIN_PATH '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \
MODEL.STAGE_WITH_GCB '(False, False, False, False)' \
SOLVER.LOSSTYPE '("softmax_smooth", "triplet")' \
OUTPUT_DIR 'logs/2019.8.27/bj/ibn_softmax_smooth'

# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
# DATASETS.NAMES '("market1501","duke","cuhk03","beijing")' \
# DATASETS.TEST_NAMES 'bj' \
# MODEL.BACKBONE 'resnet50_ibn' \
# INPUT.DO_LIGHTING 'True' \
# OUTPUT_DIR 'logs/2019.8.14/bj/lighting_ibn7_1'

@ -1,8 +0,0 @@
python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("duke",)' \
DATASETS.TEST_NAMES 'duke' \
MODEL.BACKBONE 'resnet50' \
MODEL.VERSION 'cos_triplet' \
SOLVER.LOSSTYPE '("softmax", "triplet")' \
OUTPUT_DIR 'logs/2019.9.3'

@ -1,19 +0,0 @@
gpu=2

CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("duke",)' \
DATASETS.TEST_NAMES 'duke' \
INPUT.DO_LIGHTING 'False' \
MODEL.WITH_IBN 'False' \
MODEL.STAGE_WITH_GCB '(False, False, False, False)' \
SOLVER.LOSSTYPE '("softmax_smooth", "triplet", "center")' \
OUTPUT_DIR 'logs/2019.8.28/duke/smooth_triplet_center'


# MODEL.PRETRAIN_PATH '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \
# CUDA_VISIBLE_DEVICES=$gpu python tools/train.py -cfg='configs/softmax_triplet.yml' \
# DATASETS.NAMES '("market1501",)' \
# DATASETS.TEST_NAMES 'market1501' \
# SOLVER.IMS_PER_BATCH '64' \
# INPUT.DO_LIGHTING 'True' \
# OUTPUT_DIR 'logs/market/bs64'

@ -0,0 +1,13 @@
GPUS='2'

CUDA_VISIBLE_DEVICES=$GPUS python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("dukemtmc",)' \
DATASETS.TEST_NAMES 'dukemtmc' \
INPUT.SIZE_TRAIN '[288, 144]' \
INPUT.SIZE_TEST '[288, 144]' \
SOLVER.IMS_PER_BATCH '64' \
MODEL.NAME 'mgn' \
MODEL.BACKBONE 'resnet50' \
MODEL.VERSION 'mgn++' \
SOLVER.OPT 'adam' \
SOLVER.LOSSTYPE '("softmax", "triplet")' \

@ -0,0 +1,16 @@
GPUS='2,3'

CUDA_VISIBLE_DEVICES=$GPUS python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("market1501","dukemtmc","cuhk03","msmt17")' \
DATASETS.TEST_NAMES 'beijing' \
SOLVER.IMS_PER_BATCH '256' \
MODEL.NAME 'mgn_plus' \
MODEL.WITH_IBN 'False' \
MODEL.BACKBONE 'resnet50' \
MODEL.VERSION 'combineall_bs256_mgn_plus' \
SOLVER.OPT 'adam' \
SOLVER.LOSSTYPE '("softmax", "triplet")'



# MODEL.PRETRAIN_PATH '/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \

@ -0,0 +1,15 @@
GPUS='0,1'

CUDA_VISIBLE_DEVICES=$GPUS python tools/train.py -cfg='configs/softmax_triplet.yml' \
DATASETS.NAMES '("dukemtmc",)' \
DATASETS.TEST_NAMES 'dukemtmc' \
INPUT.SIZE_TRAIN '[256, 128]' \
INPUT.SIZE_TEST '[256, 128]' \
SOLVER.IMS_PER_BATCH '256' \
MODEL.NAME 'baseline' \
MODEL.WITH_IBN 'True' \
MODEL.BACKBONE 'resnet50' \
MODEL.VERSION 'baseline_bs256' \
MODEL.PRETRAIN_PATH '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar' \
SOLVER.OPT 'adam' \
SOLVER.LOSSTYPE '("softmax", "triplet")' \
@ -22,6 +22,7 @@ def make_optimizer(cfg, model):
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPT == 'sgd': opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.OPT == 'adam': opt_fns = torch.optim.Adam(params)
    elif cfg.SOLVER.OPT == 'adamw': opt_fns = torch.optim.AdamW(params)
    else:
        raise NameError(f'optimizer {cfg.SOLVER.OPT} is not supported')
    return opt_fns
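Selecting the optimizer via the config (a sketch; `model` is the built network):

```python
cfg.SOLVER.OPT = 'adamw'
opt = make_optimizer(cfg, model)          # param groups already carry lr / weight decay
scheduler = make_lr_scheduler(cfg, opt)   # paired warmup scheduler, as used by the trainer
```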
@ -0,0 +1,43 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import sys
from fastai.vision import *
sys.path.append('.')
from data import get_dataloader
from config import cfg
import argparse
from data.datasets import init_dataset
# cfg.DATALOADER.SAMPLER = 'triplet'
cfg.DATASETS.NAMES = ("market1501", "dukemtmc", "cuhk03", "msmt17",)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument(
        '-cfg', "--config_file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str
    )
    # parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    cfg.merge_from_list(args.opts)

    # dataset = init_dataset('msmt17', combineall=True)
    get_dataloader(cfg)
    # tng_dataloader, val_dataloader, num_classes, num_query = get_dataloader(cfg)
    # def get_ex(): return open_image('datasets/beijingStation/query/000245_c10s2_1561732033722.000000.jpg')
    # im = get_ex()
    # print(data.train_ds[0])
    # print(data.test_ds[0])
    # a = next(iter(data.train_dl))
    # from IPython import embed; embed()
    # from ipdb import set_trace; set_trace()
    # im.apply_tfms(crop_pad(size=(300, 300)))
@ -1,29 +1,31 @@
import sys
import unittest
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

import sys
sys.path.append('.')
from modeling import *
sys.path.append(".")
from config import cfg
from modeling import build_model
from modeling.backbones import *
from modeling.mgn import MGN
from modeling.mgn_plus import MGN_P

class MyTestCase(unittest.TestCase):
    def test_model(self):
        cfg.MODEL.WITH_IBN = True
        cfg.MODEL.PRETRAIN_PATH = '/home/user01/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
        net = build_model(cfg, 100)
        y = net(torch.randn(2, 3, 256, 128))
        from ipdb import set_trace; set_trace()
        # net1 = ResNet.from_name('resnet50', 1, True)
        # for i in net1.named_parameters():
        #     print(i[0])
        # net2 = resnet50_ibn_a(1)
        # print('*'*10)
        # for i in net2.named_parameters():
        #     print(i[0])
cfg.MODEL.BACKBONE = 'resnet50'
cfg.MODEL.WITH_IBN = False
# cfg.MODEL.PRETRAIN_PATH = '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'

net = MGN_P('resnet50', 100, 1, False, None, cfg.MODEL.STAGE_WITH_GCB, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH)
# net = MGN('resnet50', 100, 2, False, None, cfg.MODEL.STAGE_WITH_GCB, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH)
# net.eval()
# net = net.cuda()
x = torch.randn(10, 3, 256, 128)
y = net(x)
from ipdb import set_trace; set_trace()
# label = torch.ones(10).long().cuda()
# y = net(x, label)


if __name__ == '__main__':
    unittest.main()
@ -7,17 +7,16 @@
import argparse
import os
import sys
from os import mkdir

import torch
from torch.backends import cudnn

sys.path.append('.')
from config import cfg
from data import get_data_bunch
from data import get_test_dataloader
from engine.inference import inference
from utils.logger import setup_logger
from modeling import build_model
from utils.logger import setup_logger


def main():

@ -35,11 +34,11 @@ def main():
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # set pretrain = False to avoid loading weights repeatedly
    cfg.MODEL.PRETRAIN = False
    cfg.freeze()

    if not os.path.exists(cfg.OUTPUT_DIR): os.makedirs(cfg.OUTPUT_DIR)

    logger = setup_logger("reid_baseline", cfg.OUTPUT_DIR, 0)
    logger = setup_logger("reid_baseline", False, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

@ -49,15 +48,15 @@ def main():

    cudnn.benchmark = True

    data_bunch, test_labels, num_query = get_data_bunch(cfg)
    model = build_model(cfg, data_bunch.c)
    state_dict = torch.load(cfg.TEST.WEIGHT)
    model.load_params_wo_fc(state_dict['model'])
    model.cuda()
    # model = torch.jit.load("/export/home/lxy/reid_baseline/pcb_model_v0.2.pt")
    model = build_model(cfg, 0)
    model = model.cuda()
    model.load_params_wo_fc(torch.load(cfg.TEST.WEIGHT))

    inference(cfg, model, data_bunch, test_labels, num_query)
    test_dataloader, num_query = get_test_dataloader(cfg)

    inference(cfg, model, test_dataloader, num_query)


if __name__ == '__main__':
    main()
@ -8,33 +8,17 @@ import argparse
import os
import sys

import warnings
import torch
from torch.backends import cudnn

sys.path.append(".")
from config import cfg
from data import get_dataloader
from engine.trainer import do_train
from utils.logger import setup_logger


def train(cfg, local_rank):
    # prepare dataset
    tng_loader, val_loader, num_classes, num_query = get_dataloader(cfg)

    do_train(
        cfg,
        local_rank,
        tng_loader,
        val_loader,
        num_classes,
        num_query,
    )
from engine.trainer import ReidSystem
from torch.utils.tensorboard import SummaryWriter


def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser = argparse.ArgumentParser(description="ReID Model Training")
    parser.add_argument(
        '-cfg', "--config_file",
        default="",

@ -42,7 +26,7 @@ def main():
        help="path to config file",
        type=str
    )
    parser.add_argument("--local_rank", type=int, default=0)
    # parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()

@ -51,21 +35,21 @@ def main():
    cfg.merge_from_list(args.opts)

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    cfg.SOLVER.DIST = num_gpus > 1
    # cfg.SOLVER.DIST = num_gpus > 1

    if cfg.SOLVER.DIST:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        torch.cuda.synchronize()
    # if cfg.SOLVER.DIST:
    #     torch.cuda.set_device(args.local_rank)
    #     torch.distributed.init_process_group(
    #         backend="nccl", init_method="env://"
    #     )
    #     torch.cuda.synchronize()

    cfg.freeze()

    log_save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASETS.TEST_NAMES, 'version_'+cfg.MODEL.VERSION)
    log_save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASETS.TEST_NAMES, cfg.MODEL.VERSION)
    if not os.path.exists(log_save_dir): os.makedirs(log_save_dir)

    logger = setup_logger("reid_baseline", log_save_dir, 0)
    logger = setup_logger("reid_baseline.train", log_save_dir, 0)
    logger.info("Using {} GPUs.".format(num_gpus))
    logger.info(args)

@ -73,8 +57,24 @@ def main():
    logger.info("Loaded configuration file {}".format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    logger.info('start training')
    cudnn.benchmark = True
    train(cfg, args.local_rank)

    writer = SummaryWriter(os.path.join(log_save_dir, 'tf'))
    reid_system = ReidSystem(cfg, logger, writer)
    reid_system.train()

    # TODO: continue training
    # if cfg.MODEL.CHECKPOINT is not '':
    #     state = torch.load(cfg.MODEL.CHECKPOINT)
    #     if set(state.keys()) == {'model', 'opt'}:
    #         model_state = state['model']
    #         learn.model.load_state_dict(model_state)
    #         learn.create_opt(0, 0)
    #         learn.opt.load_state_dict(state['opt'])
    #     else:
    #         learn.model.load_state_dict(state['model'])
    #     logger.info(f'continue training from checkpoint {cfg.MODEL.CHECKPOINT}')


if __name__ == '__main__':
@ -0,0 +1,23 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
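A quick usage sketch: the trainer keeps one of these per epoch to report a running mean of the total loss:

```python
meter = AverageMeter()
for loss in (0.9, 0.7, 0.5):
    meter.update(loss)   # n=1 per batch
print(meter.avg)         # ~0.7, the running average
```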