mirror of https://github.com/JDAI-CV/fast-reid.git
Finish refactor code by fastai
parent
852bb8ae8b
commit
29630d1290
|
@ -1,6 +1,6 @@
|
||||||
# encoding: utf-8
|
# encoding: utf-8
|
||||||
"""
|
"""
|
||||||
@author: sherlock
|
@author: l1aoxingyu
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,9 @@ _C.MODEL.PRETRAIN_PATH = ''
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
_C.INPUT = CN()
|
_C.INPUT = CN()
|
||||||
# Size of the image during training
|
# Size of the image during training
|
||||||
_C.INPUT.SIZE_TRAIN = [384, 128]
|
_C.INPUT.SIZE_TRAIN = [256, 128]
|
||||||
# Size of the image during test
|
# Size of the image during test
|
||||||
_C.INPUT.SIZE_TEST = [384, 128]
|
_C.INPUT.SIZE_TEST = [256, 128]
|
||||||
# Random probability for image horizontal flip
|
# Random probability for image horizontal flip
|
||||||
_C.INPUT.PROB = 0.5
|
_C.INPUT.PROB = 0.5
|
||||||
# Values to be used for image normalization
|
# Values to be used for image normalization
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
MODEL:
|
MODEL:
|
||||||
PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
|
PRETRAIN_PATH: 'home/user01/.torch/models/resnet50-19c8e357.pth'
|
||||||
|
|
||||||
|
|
||||||
INPUT:
|
INPUT:
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
MODEL:
|
MODEL:
|
||||||
PRETRAIN_PATH: '/export/home/lxy/.torch/models/resnet50-19c8e357.pth'
|
PRETRAIN_PATH: '/home/user01/.torch/models/resnet50-19c8e357.pth'
|
||||||
|
|
||||||
|
|
||||||
INPUT:
|
INPUT:
|
||||||
SIZE_TRAIN: [384, 128]
|
SIZE_TRAIN: [256, 128]
|
||||||
SIZE_TEST: [384, 128]
|
SIZE_TEST: [256, 128]
|
||||||
PROB: 0.5 # random horizontal flip
|
PROB: 0.5 # random horizontal flip
|
||||||
PADDING: 10
|
PADDING: 10
|
||||||
|
|
||||||
|
@ -14,13 +14,11 @@ DATASETS:
|
||||||
DATALOADER:
|
DATALOADER:
|
||||||
SAMPLER: 'softmax_triplet'
|
SAMPLER: 'softmax_triplet'
|
||||||
NUM_INSTANCE: 4
|
NUM_INSTANCE: 4
|
||||||
NUM_WORKERS: 8
|
|
||||||
|
|
||||||
SOLVER:
|
SOLVER:
|
||||||
OPTIMIZER_NAME: 'Adam'
|
OPTIMIZER_NAME: 'Adam'
|
||||||
MAX_EPOCHS: 120
|
MAX_EPOCHS: 120
|
||||||
BASE_LR: 0.00035
|
BASE_LR: 0.00035
|
||||||
BIAS_LR_FACTOR: 1
|
|
||||||
WEIGHT_DECAY: 0.0005
|
WEIGHT_DECAY: 0.0005
|
||||||
WEIGHT_DECAY_BIAS: 0.0005
|
WEIGHT_DECAY_BIAS: 0.0005
|
||||||
IMS_PER_BATCH: 64
|
IMS_PER_BATCH: 64
|
||||||
|
@ -30,16 +28,13 @@ SOLVER:
|
||||||
|
|
||||||
WARMUP_FACTOR: 0.01
|
WARMUP_FACTOR: 0.01
|
||||||
WARMUP_ITERS: 10
|
WARMUP_ITERS: 10
|
||||||
WARMUP_METHOD: 'linear'
|
|
||||||
|
|
||||||
CHECKPOINT_PERIOD: 40
|
|
||||||
LOG_PERIOD: 100
|
|
||||||
EVAL_PERIOD: 40
|
EVAL_PERIOD: 40
|
||||||
|
|
||||||
TEST:
|
TEST:
|
||||||
IMS_PER_BATCH: 256
|
IMS_PER_BATCH: 512
|
||||||
WEIGHT: "path"
|
WEIGHT: "path"
|
||||||
|
|
||||||
OUTPUT_DIR: "/export/home/lxy/CHECKPOINTS/reid/market1501/softmax_triplet_bs128_384x128"
|
OUTPUT_DIR: "/home/user01/l1aoxingyu/reid_baseline/logs/co-train/softmax_triplet_bs64_256x128"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,4 +4,4 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .build import make_data_loader
|
from .build import get_data_bunch
|
||||||
|
|
|
@ -1,44 +1,69 @@
|
||||||
# encoding: utf-8
|
# encoding: utf-8
|
||||||
"""
|
"""
|
||||||
@author: liaoxingyu
|
@author: l1aoxingyu
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
import glob
|
||||||
|
|
||||||
from .collate_batch import train_collate_fn, val_collate_fn
|
from fastai.vision import *
|
||||||
from .datasets import init_dataset, ImageDataset
|
from .transforms import RandomErasing
|
||||||
from .samplers import RandomIdentitySampler
|
from .samplers import RandomIdentitySampler
|
||||||
from .transforms import build_transforms
|
|
||||||
|
|
||||||
|
|
||||||
def make_data_loader(cfg):
|
def get_data_bunch(cfg):
|
||||||
train_transforms = build_transforms(cfg, is_train=True)
|
ds_tfms = (
|
||||||
val_transforms = build_transforms(cfg, is_train=False)
|
[flip_lr(p=0.5),
|
||||||
num_workers = cfg.DATALOADER.NUM_WORKERS
|
*rand_pad(padding=cfg.INPUT.PADDING, size=cfg.INPUT.SIZE_TRAIN, mode='zeros'),
|
||||||
if len(cfg.DATASETS.NAMES) == 1:
|
RandomErasing()
|
||||||
dataset = init_dataset(cfg.DATASETS.NAMES)
|
],
|
||||||
else:
|
None
|
||||||
# TODO: add multi dataset to train
|
|
||||||
dataset = init_dataset(cfg.DATASETS.NAMES)
|
|
||||||
|
|
||||||
num_classes = dataset.num_train_pids
|
|
||||||
train_set = ImageDataset(dataset.train, train_transforms)
|
|
||||||
if cfg.DATALOADER.SAMPLER == 'softmax':
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
|
|
||||||
collate_fn=train_collate_fn
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
train_loader = DataLoader(
|
|
||||||
train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
|
|
||||||
sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
|
|
||||||
num_workers=num_workers, collate_fn=train_collate_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
|
|
||||||
val_loader = DataLoader(
|
|
||||||
val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
|
|
||||||
collate_fn=val_collate_fn
|
|
||||||
)
|
)
|
||||||
return train_loader, val_loader, len(dataset.query), num_classes
|
|
||||||
|
def _process_dir(dir_path):
|
||||||
|
img_paths = glob.glob(os.path.join(dir_path, '*.jpg'))
|
||||||
|
pattern = re.compile(r'([-\d]+)_c(\d)')
|
||||||
|
|
||||||
|
pid_container = set()
|
||||||
|
v_paths = []
|
||||||
|
for img_path in img_paths:
|
||||||
|
pid, camid = map(int, pattern.search(img_path).groups())
|
||||||
|
if pid == -1: continue # junk images are just ignored
|
||||||
|
pid_container.add(pid)
|
||||||
|
v_paths.append([img_path, pid, camid])
|
||||||
|
return v_paths
|
||||||
|
|
||||||
|
market_train_path = 'datasets/Market-1501-v15.09.15/bounding_box_train'
|
||||||
|
duke_train_path = 'datasets/DukeMTMC-reID/bounding_box_train'
|
||||||
|
cuhk03_train_path = ''
|
||||||
|
query_path = 'datasets/Market-1501-v15.09.15/query'
|
||||||
|
gallery_path = 'datasets/Market-1501-v15.09.15/bounding_box_test'
|
||||||
|
|
||||||
|
train_img_names = _process_dir(market_train_path) + _process_dir(duke_train_path)
|
||||||
|
train_names = [i[0] for i in train_img_names]
|
||||||
|
|
||||||
|
query_names = _process_dir(query_path)
|
||||||
|
gallery_names = _process_dir(gallery_path)
|
||||||
|
test_fnames = []
|
||||||
|
test_labels = []
|
||||||
|
for i in query_names+gallery_names:
|
||||||
|
test_fnames.append(i[0])
|
||||||
|
test_labels.append(i[1:])
|
||||||
|
|
||||||
|
def get_labels(file_path):
|
||||||
|
""" Suitable for muilti-dataset training """
|
||||||
|
prefix = file_path.split('/')[1]
|
||||||
|
pat = re.compile(r'([-\d]+)_c(\d)')
|
||||||
|
pid, _ = pat.search(file_path).groups()
|
||||||
|
return prefix + '_' + pid
|
||||||
|
|
||||||
|
data_sampler = RandomIdentitySampler(train_names, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) \
|
||||||
|
if cfg.DATALOADER.SAMPLER == 'softmax_triplet' else None
|
||||||
|
data_bunch = ImageDataBunch.from_name_func('datasets', train_names, label_func=get_labels, valid_pct=0,
|
||||||
|
size=(256, 128), ds_tfms=ds_tfms, bs=cfg.SOLVER.IMS_PER_BATCH,
|
||||||
|
val_bs=cfg.TEST.IMS_PER_BATCH,
|
||||||
|
sampler=data_sampler)
|
||||||
|
data_bunch.add_test(test_fnames)
|
||||||
|
data_bunch.normalize(imagenet_stats)
|
||||||
|
|
||||||
|
return data_bunch, test_labels, len(query_names)
|
||||||
|
|
|
@ -5,14 +5,15 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from fastai.vision import *
|
||||||
|
|
||||||
|
|
||||||
def train_collate_fn(batch):
|
def test_collate_fn(batch):
|
||||||
imgs, pids, _, _, = zip(*batch)
|
imgs, label = zip(*batch)
|
||||||
pids = torch.tensor(pids, dtype=torch.int64)
|
imgs = to_data(imgs)
|
||||||
return torch.stack(imgs, dim=0), pids
|
pids = []
|
||||||
|
camids = []
|
||||||
|
for i in label:
|
||||||
def val_collate_fn(batch):
|
pids.append(i.obj[0])
|
||||||
imgs, pids, camids, _ = zip(*batch)
|
camids.append(i.obj[1])
|
||||||
return torch.stack(imgs, dim=0), pids, camids
|
return torch.stack(imgs, dim=0), (torch.LongTensor(pids), torch.LongTensor(camids))
|
||||||
|
|
|
@ -3,18 +3,33 @@
|
||||||
@author: liaoxingyu
|
@author: liaoxingyu
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import copy
|
||||||
|
from collections import defaultdict
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
try:
|
||||||
|
from csrc.eval_cylib.eval_metrics_cy import evaluate_cy
|
||||||
|
IS_CYTHON_AVAI = True
|
||||||
|
print("Using Cython evaluation code as the backend")
|
||||||
|
except ImportError:
|
||||||
|
IS_CYTHON_AVAI = False
|
||||||
|
warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended")
|
||||||
|
|
||||||
|
|
||||||
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
|
def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
||||||
"""Evaluation with market1501 metric
|
"""Evaluation with cuhk03 metric
|
||||||
Key: for each query identity, its gallery images from the same camera view are discarded.
|
Key: one image for each gallery identity is randomly sampled for each query identity.
|
||||||
"""
|
Random sampling is performed num_repeats times.
|
||||||
|
"""
|
||||||
|
num_repeats = 10
|
||||||
num_q, num_g = distmat.shape
|
num_q, num_g = distmat.shape
|
||||||
|
|
||||||
if num_g < max_rank:
|
if num_g < max_rank:
|
||||||
max_rank = num_g
|
max_rank = num_g
|
||||||
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||||
|
|
||||||
indices = np.argsort(distmat, axis=1)
|
indices = np.argsort(distmat, axis=1)
|
||||||
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||||
|
|
||||||
|
@ -22,6 +37,7 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
|
||||||
all_cmc = []
|
all_cmc = []
|
||||||
all_AP = []
|
all_AP = []
|
||||||
num_valid_q = 0. # number of valid query
|
num_valid_q = 0. # number of valid query
|
||||||
|
|
||||||
for q_idx in range(num_q):
|
for q_idx in range(num_q):
|
||||||
# get query pid and camid
|
# get query pid and camid
|
||||||
q_pid = q_pids[q_idx]
|
q_pid = q_pids[q_idx]
|
||||||
|
@ -33,13 +49,84 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
|
||||||
keep = np.invert(remove)
|
keep = np.invert(remove)
|
||||||
|
|
||||||
# compute cmc curve
|
# compute cmc curve
|
||||||
# binary vector, positions with value 1 are correct matches
|
raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
|
||||||
orig_cmc = matches[q_idx][keep]
|
if not np.any(raw_cmc):
|
||||||
if not np.any(orig_cmc):
|
|
||||||
# this condition is true when query identity does not appear in gallery
|
# this condition is true when query identity does not appear in gallery
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cmc = orig_cmc.cumsum()
|
kept_g_pids = g_pids[order][keep]
|
||||||
|
g_pids_dict = defaultdict(list)
|
||||||
|
for idx, pid in enumerate(kept_g_pids):
|
||||||
|
g_pids_dict[pid].append(idx)
|
||||||
|
|
||||||
|
cmc, AP = 0., 0.
|
||||||
|
for repeat_idx in range(num_repeats):
|
||||||
|
mask = np.zeros(len(raw_cmc), dtype=np.bool)
|
||||||
|
for _, idxs in g_pids_dict.items():
|
||||||
|
# randomly sample one image for each gallery person
|
||||||
|
rnd_idx = np.random.choice(idxs)
|
||||||
|
mask[rnd_idx] = True
|
||||||
|
masked_raw_cmc = raw_cmc[mask]
|
||||||
|
_cmc = masked_raw_cmc.cumsum()
|
||||||
|
_cmc[_cmc > 1] = 1
|
||||||
|
cmc += _cmc[:max_rank].astype(np.float32)
|
||||||
|
# compute AP
|
||||||
|
num_rel = masked_raw_cmc.sum()
|
||||||
|
tmp_cmc = masked_raw_cmc.cumsum()
|
||||||
|
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||||
|
tmp_cmc = np.asarray(tmp_cmc) * masked_raw_cmc
|
||||||
|
AP += tmp_cmc.sum() / num_rel
|
||||||
|
|
||||||
|
cmc /= num_repeats
|
||||||
|
AP /= num_repeats
|
||||||
|
all_cmc.append(cmc)
|
||||||
|
all_AP.append(AP)
|
||||||
|
num_valid_q += 1.
|
||||||
|
|
||||||
|
assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
|
||||||
|
|
||||||
|
all_cmc = np.asarray(all_cmc).astype(np.float32)
|
||||||
|
all_cmc = all_cmc.sum(0) / num_valid_q
|
||||||
|
mAP = np.mean(all_AP)
|
||||||
|
|
||||||
|
return all_cmc, mAP
|
||||||
|
|
||||||
|
|
||||||
|
def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
|
||||||
|
"""Evaluation with market1501 metric
|
||||||
|
Key: for each query identity, its gallery images from the same camera view are discarded.
|
||||||
|
"""
|
||||||
|
num_q, num_g = distmat.shape
|
||||||
|
|
||||||
|
if num_g < max_rank:
|
||||||
|
max_rank = num_g
|
||||||
|
print("Note: number of gallery samples is quite small, got {}".format(num_g))
|
||||||
|
|
||||||
|
indices = np.argsort(distmat, axis=1)
|
||||||
|
matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
|
||||||
|
|
||||||
|
# compute cmc curve for each query
|
||||||
|
all_cmc = []
|
||||||
|
all_AP = []
|
||||||
|
num_valid_q = 0. # number of valid query
|
||||||
|
|
||||||
|
for q_idx in range(num_q):
|
||||||
|
# get query pid and camid
|
||||||
|
q_pid = q_pids[q_idx]
|
||||||
|
q_camid = q_camids[q_idx]
|
||||||
|
|
||||||
|
# remove gallery samples that have the same pid and camid with query
|
||||||
|
order = indices[q_idx]
|
||||||
|
remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
|
||||||
|
keep = np.invert(remove)
|
||||||
|
|
||||||
|
# compute cmc curve
|
||||||
|
raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
|
||||||
|
if not np.any(raw_cmc):
|
||||||
|
# this condition is true when query identity does not appear in gallery
|
||||||
|
continue
|
||||||
|
|
||||||
|
cmc = raw_cmc.cumsum()
|
||||||
cmc[cmc > 1] = 1
|
cmc[cmc > 1] = 1
|
||||||
|
|
||||||
all_cmc.append(cmc[:max_rank])
|
all_cmc.append(cmc[:max_rank])
|
||||||
|
@ -47,10 +134,10 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
|
||||||
|
|
||||||
# compute average precision
|
# compute average precision
|
||||||
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
# reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
|
||||||
num_rel = orig_cmc.sum()
|
num_rel = raw_cmc.sum()
|
||||||
tmp_cmc = orig_cmc.cumsum()
|
tmp_cmc = raw_cmc.cumsum()
|
||||||
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
|
||||||
tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
|
tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
|
||||||
AP = tmp_cmc.sum() / num_rel
|
AP = tmp_cmc.sum() / num_rel
|
||||||
all_AP.append(AP)
|
all_AP.append(AP)
|
||||||
|
|
||||||
|
@ -61,3 +148,19 @@ def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
|
||||||
mAP = np.mean(all_AP)
|
mAP = np.mean(all_AP)
|
||||||
|
|
||||||
return all_cmc, mAP
|
return all_cmc, mAP
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03):
|
||||||
|
if use_metric_cuhk03:
|
||||||
|
return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||||
|
else:
|
||||||
|
return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, use_cython=True):
|
||||||
|
if use_cython and IS_CYTHON_AVAI:
|
||||||
|
return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
|
||||||
|
else:
|
||||||
|
return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,13 @@
|
||||||
@contact: liaoxingyu2@jd.com
|
@contact: liaoxingyu2@jd.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
|
||||||
import random
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import random
|
||||||
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,12 +25,17 @@ class RandomIdentitySampler(Sampler):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_source, batch_size, num_instances):
|
def __init__(self, data_source, batch_size, num_instances):
|
||||||
|
pat = re.compile(r'([-\d]+)_c(\d)')
|
||||||
|
|
||||||
self.data_source = data_source
|
self.data_source = data_source
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_instances = num_instances
|
self.num_instances = num_instances
|
||||||
self.num_pids_per_batch = self.batch_size // self.num_instances
|
self.num_pids_per_batch = self.batch_size // self.num_instances
|
||||||
self.index_dic = defaultdict(list)
|
self.index_dic = defaultdict(list)
|
||||||
for index, (_, pid, _) in enumerate(self.data_source):
|
for index, fname in enumerate(self.data_source):
|
||||||
|
prefix = fname.split('/')[1]
|
||||||
|
pid, _ = pat.search(fname).groups()
|
||||||
|
pid = prefix + '_' + pid
|
||||||
self.index_dic[pid].append(index)
|
self.index_dic[pid].append(index)
|
||||||
self.pids = list(self.index_dic.keys())
|
self.pids = list(self.index_dic.keys())
|
||||||
|
|
||||||
|
@ -71,3 +78,27 @@ class RandomIdentitySampler(Sampler):
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
|
# class RandomIdentitySampler(Sampler):
|
||||||
|
# def __init__(self, data_source, num_instances=4):
|
||||||
|
# self.data_source = data_source
|
||||||
|
# self.num_instances = num_instances
|
||||||
|
# self.index_dic = defaultdict(list)
|
||||||
|
# for index, (_, pid) in enumerate(data_source):
|
||||||
|
# self.index_dic[pid].append(index)
|
||||||
|
# self.pids = list(self.index_dic.keys())
|
||||||
|
# self.num_identities = len(self.pids)
|
||||||
|
#
|
||||||
|
# def __iter__(self):
|
||||||
|
# indices = torch.randperm(self.num_identities)
|
||||||
|
# ret = []
|
||||||
|
# for i in indices:
|
||||||
|
# pid = self.pids[i]
|
||||||
|
# t = self.index_dic[pid]
|
||||||
|
# replace = False if len(t) >= self.num_instances else True
|
||||||
|
# t = np.random.choice(t, size=self.num_instances, replace=replace)
|
||||||
|
# ret.extend(t)
|
||||||
|
# return iter(ret)
|
||||||
|
#
|
||||||
|
# def __len__(self):
|
||||||
|
# return self.num_identities * self.num_instances
|
||||||
|
|
|
@ -4,4 +4,4 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .build import build_transforms
|
from .transforms import *
|
||||||
|
|
|
@ -1,31 +0,0 @@
|
||||||
# encoding: utf-8
|
|
||||||
"""
|
|
||||||
@author: liaoxingyu
|
|
||||||
@contact: liaoxingyu2@jd.com
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torchvision.transforms as T
|
|
||||||
|
|
||||||
from .transforms import RandomErasing
|
|
||||||
|
|
||||||
|
|
||||||
def build_transforms(cfg, is_train=True):
|
|
||||||
normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
|
|
||||||
if is_train:
|
|
||||||
transform = T.Compose([
|
|
||||||
T.Resize(cfg.INPUT.SIZE_TRAIN),
|
|
||||||
T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
|
|
||||||
T.Pad(cfg.INPUT.PADDING),
|
|
||||||
T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
|
|
||||||
T.ToTensor(),
|
|
||||||
normalize_transform,
|
|
||||||
RandomErasing(probability=cfg.INPUT.PROB, mean=cfg.INPUT.PIXEL_MEAN)
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
transform = T.Compose([
|
|
||||||
T.Resize(cfg.INPUT.SIZE_TEST),
|
|
||||||
T.ToTensor(),
|
|
||||||
normalize_transform
|
|
||||||
])
|
|
||||||
|
|
||||||
return transform
|
|
|
@ -4,52 +4,34 @@
|
||||||
@contact: liaoxingyu2@jd.com
|
@contact: liaoxingyu2@jd.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
from fastai.vision import *
|
||||||
import random
|
from fastai.vision.image import *
|
||||||
|
|
||||||
|
|
||||||
class RandomErasing(object):
|
def _random_erasing(x, probability=0.5, sl=0.02, sh=0.4, r1=0.3,
|
||||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
mean=(np.array(imagenet_stats[1]) + 1) * imagenet_stats[0]):
|
||||||
'Random Erasing Data Augmentation' by Zhong et al.
|
if random.uniform(0, 1) > probability:
|
||||||
See https://arxiv.org/pdf/1708.04896.pdf
|
return x
|
||||||
Args:
|
|
||||||
probability: The probability that the Random Erasing operation will be performed.
|
|
||||||
sl: Minimum proportion of erased area against input image.
|
|
||||||
sh: Maximum proportion of erased area against input image.
|
|
||||||
r1: Minimum aspect ratio of erased area.
|
|
||||||
mean: Erasing value.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
|
for attempt in range(100):
|
||||||
self.probability = probability
|
area = x.size()[1] * x.size()[2]
|
||||||
self.mean = mean
|
|
||||||
self.sl = sl
|
|
||||||
self.sh = sh
|
|
||||||
self.r1 = r1
|
|
||||||
|
|
||||||
def __call__(self, img):
|
target_area = random.uniform(sl, sh) * area
|
||||||
|
aspect_ratio = random.uniform(r1, 1 / r1)
|
||||||
|
|
||||||
if random.uniform(0, 1) > self.probability:
|
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||||
return img
|
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||||
|
|
||||||
for attempt in range(100):
|
if w < x.size()[2] and h < x.size()[1]:
|
||||||
area = img.size()[1] * img.size()[2]
|
x1 = random.randint(0, x.size()[1] - h)
|
||||||
|
y1 = random.randint(0, x.size()[2] - w)
|
||||||
|
if x.size()[0] == 3:
|
||||||
|
x[0, x1:x1 + h, y1:y1 + w] = mean[0]
|
||||||
|
x[1, x1:x1 + h, y1:y1 + w] = mean[1]
|
||||||
|
x[2, x1:x1 + h, y1:y1 + w] = mean[2]
|
||||||
|
else:
|
||||||
|
x[0, x1:x1 + h, y1:y1 + w] = mean[0]
|
||||||
|
return x
|
||||||
|
|
||||||
target_area = random.uniform(self.sl, self.sh) * area
|
|
||||||
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
|
|
||||||
|
|
||||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
RandomErasing = TfmPixel(_random_erasing)
|
||||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
|
||||||
|
|
||||||
if w < img.size()[2] and h < img.size()[1]:
|
|
||||||
x1 = random.randint(0, img.size()[1] - h)
|
|
||||||
y1 = random.randint(0, img.size()[2] - w)
|
|
||||||
if img.size()[0] == 3:
|
|
||||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
|
|
||||||
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
|
||||||
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
|
|
||||||
else:
|
|
||||||
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
|
|
||||||
return img
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
|
@ -4,147 +4,88 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
from data.datasets.eval_reid import evaluate
|
||||||
|
from fastai.vision import *
|
||||||
import torch
|
|
||||||
from ignite.engine import Engine, Events
|
|
||||||
from ignite.handlers import ModelCheckpoint, Timer
|
|
||||||
from ignite.metrics import RunningAverage
|
|
||||||
|
|
||||||
from utils.reid_metric import R1_mAP
|
|
||||||
|
|
||||||
|
|
||||||
def create_supervised_trainer(model, optimizer, loss_fn,
|
class LrScheduler(LearnerCallback):
|
||||||
device=None):
|
def __init__(self, learn: Learner, lr_sched: Scheduler):
|
||||||
"""
|
super().__init__(learn)
|
||||||
Factory function for creating a trainer for supervised models
|
self.lr_sched = lr_sched
|
||||||
|
|
||||||
Args:
|
def on_train_begin(self, **kwargs: Any) -> None:
|
||||||
model (`torch.nn.Module`): the model to train
|
self.opt = self.learn.opt
|
||||||
optimizer (`torch.optim.Optimizer`): the optimizer to use
|
|
||||||
loss_fn (torch.nn loss function): the loss function to use
|
|
||||||
device (str, optional): device type specification (default: None).
|
|
||||||
Applies to both model and batches.
|
|
||||||
|
|
||||||
Returns:
|
def on_epoch_begin(self, **kwargs: Any) -> None:
|
||||||
Engine: a trainer engine with supervised update function
|
self.opt.lr = self.lr_sched.step()
|
||||||
"""
|
|
||||||
if device:
|
|
||||||
model.to(device)
|
|
||||||
|
|
||||||
def _update(engine, batch):
|
|
||||||
model.train()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
img, target = batch
|
|
||||||
img = img.cuda()
|
|
||||||
target = target.cuda()
|
|
||||||
score, feat = model(img)
|
|
||||||
loss = loss_fn(score, feat, target)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
# compute acc
|
|
||||||
acc = (score.max(1)[1] == target).float().mean()
|
|
||||||
return loss.item(), acc.item()
|
|
||||||
|
|
||||||
return Engine(_update)
|
|
||||||
|
|
||||||
|
|
||||||
def create_supervised_evaluator(model, metrics,
|
class TestModel(LearnerCallback):
|
||||||
device=None):
|
def __init__(self, learn: Learner, test_labels: Iterator, eval_period: int, num_query: int, output_dir: Path):
|
||||||
"""
|
super().__init__(learn)
|
||||||
Factory function for creating an evaluator for supervised models
|
self.test_dl = learn.data.test_dl
|
||||||
|
self.eval_period = eval_period
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.num_query = num_query
|
||||||
|
pids = []
|
||||||
|
camids = []
|
||||||
|
for i in test_labels:
|
||||||
|
pids.append(i[0])
|
||||||
|
camids.append(i[1])
|
||||||
|
self.q_pids = np.asarray(pids[:num_query])
|
||||||
|
self.q_camids = np.asarray(camids[:num_query])
|
||||||
|
self.g_pids = np.asarray(pids[num_query:])
|
||||||
|
self.g_camids = np.asarray(camids[num_query:])
|
||||||
|
|
||||||
Args:
|
def on_epoch_end(self, epoch, **kwargs: Any) -> None:
|
||||||
model (`torch.nn.Module`): the model to train
|
# test model performance
|
||||||
metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
|
if (epoch + 1) % self.eval_period == 0:
|
||||||
device (str, optional): device type specification (default: None).
|
print('Testing ...')
|
||||||
Applies to both model and batches.
|
feats, pids, camids = [], [], []
|
||||||
Returns:
|
self.learn.model.eval()
|
||||||
Engine: an evaluator engine with supervised inference function
|
with torch.no_grad():
|
||||||
"""
|
for imgs, _ in self.test_dl:
|
||||||
if device:
|
feat = self.learn.model(imgs)
|
||||||
model.to(device)
|
feats.append(feat)
|
||||||
|
|
||||||
def _inference(engine, batch):
|
feats = torch.cat(feats, dim=0)
|
||||||
model.eval()
|
# query
|
||||||
with torch.no_grad():
|
qf = feats[:self.num_query]
|
||||||
data, pids, camids = batch
|
# gallery
|
||||||
data = data.cuda()
|
gf = feats[self.num_query:]
|
||||||
feat = model(data)
|
m, n = qf.shape[0], gf.shape[0]
|
||||||
return feat, pids, camids
|
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
|
||||||
|
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
||||||
engine = Engine(_inference)
|
distmat.addmm_(1, -2, qf, gf.t())
|
||||||
|
distmat = to_np(distmat)
|
||||||
for name, metric in metrics.items():
|
cmc, mAP = evaluate(distmat, self.q_pids, self.g_pids, self.q_camids, self.g_camids)
|
||||||
metric.attach(engine, name)
|
print("Test Results - Epoch: {}".format(epoch + 1))
|
||||||
|
print("mAP: {:.1%}".format(mAP))
|
||||||
return engine
|
for r in [1, 5, 10]:
|
||||||
|
print("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
|
||||||
|
self.learn.save(self.output_dir / 'reid_model_{}'.format(epoch))
|
||||||
|
|
||||||
|
|
||||||
def do_train(
|
def do_train(
|
||||||
cfg,
|
cfg,
|
||||||
model,
|
model,
|
||||||
train_loader,
|
data_bunch,
|
||||||
val_loader,
|
test_labels,
|
||||||
optimizer,
|
opt_func,
|
||||||
scheduler,
|
lr_sched,
|
||||||
loss_fn,
|
loss_func,
|
||||||
num_query
|
num_query
|
||||||
):
|
):
|
||||||
log_period = cfg.SOLVER.LOG_PERIOD
|
|
||||||
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
|
|
||||||
eval_period = cfg.SOLVER.EVAL_PERIOD
|
eval_period = cfg.SOLVER.EVAL_PERIOD
|
||||||
output_dir = cfg.OUTPUT_DIR
|
output_dir = cfg.OUTPUT_DIR
|
||||||
device = cfg.MODEL.DEVICE
|
|
||||||
epochs = cfg.SOLVER.MAX_EPOCHS
|
epochs = cfg.SOLVER.MAX_EPOCHS
|
||||||
|
|
||||||
logger = logging.getLogger("reid_baseline.train")
|
print("Start training")
|
||||||
logger.info("Start training")
|
|
||||||
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
|
|
||||||
evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device)
|
|
||||||
checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
|
|
||||||
timer = Timer(average=True)
|
|
||||||
|
|
||||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
|
learn = Learner(data_bunch, model, opt_func=opt_func, loss_func=loss_func, true_wd=False)
|
||||||
'optimizer': optimizer.state_dict()})
|
|
||||||
timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
|
|
||||||
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
|
|
||||||
|
|
||||||
# average metric to attach on trainer
|
lr_sched_cb = LrScheduler(learn, lr_sched)
|
||||||
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
|
testmodel_cb = TestModel(learn, test_labels, eval_period, num_query, Path(output_dir))
|
||||||
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
|
|
||||||
|
|
||||||
@trainer.on(Events.EPOCH_STARTED)
|
learn.fit(epochs, callbacks=[lr_sched_cb, testmodel_cb],
|
||||||
def adjust_learning_rate(engine):
|
lr=cfg.SOLVER.BASE_LR, wd=cfg.SOLVER.WEIGHT_DECAY)
|
||||||
scheduler.step()
|
|
||||||
|
|
||||||
@trainer.on(Events.ITERATION_COMPLETED)
|
|
||||||
def log_training_loss(engine):
|
|
||||||
iter = (engine.state.iteration - 1) % len(train_loader) + 1
|
|
||||||
|
|
||||||
if iter % log_period == 0:
|
|
||||||
logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
|
|
||||||
.format(engine.state.epoch, iter, len(train_loader),
|
|
||||||
engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
|
|
||||||
scheduler.get_lr()[0]))
|
|
||||||
|
|
||||||
# adding handlers using `trainer.on` decorator API
|
|
||||||
@trainer.on(Events.EPOCH_COMPLETED)
|
|
||||||
def print_times(engine):
|
|
||||||
logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
|
|
||||||
.format(engine.state.epoch, timer.value() * timer.step_count,
|
|
||||||
train_loader.batch_size / timer.value()))
|
|
||||||
logger.info('-' * 10)
|
|
||||||
timer.reset()
|
|
||||||
|
|
||||||
@trainer.on(Events.EPOCH_COMPLETED)
|
|
||||||
def log_validation_results(engine):
|
|
||||||
if engine.state.epoch % eval_period == 0:
|
|
||||||
evaluator.run(val_loader)
|
|
||||||
cmc, mAP = evaluator.state.metrics['r1_mAP']
|
|
||||||
logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
|
|
||||||
logger.info("mAP: {:.1%}".format(mAP))
|
|
||||||
for r in [1, 5, 10]:
|
|
||||||
logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
|
|
||||||
|
|
||||||
trainer.run(train_loader, max_epochs=epochs)
|
|
||||||
|
|
|
@ -14,13 +14,16 @@ def make_loss(cfg):
|
||||||
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss
|
||||||
|
|
||||||
if sampler == 'softmax':
|
if sampler == 'softmax':
|
||||||
def loss_func(score, feat, target):
|
def loss_func(out, target):
|
||||||
|
score, feat = out
|
||||||
return F.cross_entropy(score, target)
|
return F.cross_entropy(score, target)
|
||||||
elif cfg.DATALOADER.SAMPLER == 'triplet':
|
elif cfg.DATALOADER.SAMPLER == 'triplet':
|
||||||
def loss_func(score, feat, target):
|
def loss_func(out, target):
|
||||||
|
score, feat = out
|
||||||
return triplet(feat, target)[0]
|
return triplet(feat, target)[0]
|
||||||
elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
|
elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
|
||||||
def loss_func(score, feat, target):
|
def loss_func(out, target):
|
||||||
|
score, feat = out
|
||||||
return F.cross_entropy(score, target) + triplet(feat, target)[0]
|
return F.cross_entropy(score, target) + triplet(feat, target)[0]
|
||||||
else:
|
else:
|
||||||
print('expected sampler should be softmax, triplet or softmax_triplet, '
|
print('expected sampler should be softmax, triplet or softmax_triplet, '
|
||||||
|
|
|
@ -4,5 +4,4 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .build import make_optimizer
|
from .build import *
|
||||||
from .lr_scheduler import WarmupMultiStepLR
|
|
||||||
|
|
|
@ -4,22 +4,12 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
from fastai.vision import *
|
||||||
|
|
||||||
|
|
||||||
def make_optimizer(cfg, model):
|
def make_optimizer(cfg):
|
||||||
params = []
|
|
||||||
for key, value in model.named_parameters():
|
|
||||||
if not value.requires_grad:
|
|
||||||
continue
|
|
||||||
lr = cfg.SOLVER.BASE_LR
|
|
||||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY
|
|
||||||
if "bias" in key:
|
|
||||||
lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
|
|
||||||
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
|
|
||||||
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
|
|
||||||
if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
|
if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
|
||||||
optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
|
opt = partial(getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME), momentum=cfg.SOLVER.MOMENTUM)
|
||||||
else:
|
else:
|
||||||
optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
|
opt = partial(getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME))
|
||||||
return optimizer
|
return opt
|
||||||
|
|
|
@ -1,56 +0,0 @@
|
||||||
# encoding: utf-8
|
|
||||||
"""
|
|
||||||
@author: liaoxingyu
|
|
||||||
@contact: sherlockliao01@gmail.com
|
|
||||||
"""
|
|
||||||
from bisect import bisect_right
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
# FIXME ideally this would be achieved with a CombinedLRScheduler,
|
|
||||||
# separating MultiStepLR with WarmupLR
|
|
||||||
# but the current LRScheduler design doesn't allow it
|
|
||||||
|
|
||||||
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
optimizer,
|
|
||||||
milestones,
|
|
||||||
gamma=0.1,
|
|
||||||
warmup_factor=1.0 / 3,
|
|
||||||
warmup_iters=500,
|
|
||||||
warmup_method="linear",
|
|
||||||
last_epoch=-1,
|
|
||||||
):
|
|
||||||
if not list(milestones) == sorted(milestones):
|
|
||||||
raise ValueError(
|
|
||||||
"Milestones should be a list of" " increasing integers. Got {}",
|
|
||||||
milestones,
|
|
||||||
)
|
|
||||||
|
|
||||||
if warmup_method not in ("constant", "linear"):
|
|
||||||
raise ValueError(
|
|
||||||
"Only 'constant' or 'linear' warmup_method accepted"
|
|
||||||
"got {}".format(warmup_method)
|
|
||||||
)
|
|
||||||
self.milestones = milestones
|
|
||||||
self.gamma = gamma
|
|
||||||
self.warmup_factor = warmup_factor
|
|
||||||
self.warmup_iters = warmup_iters
|
|
||||||
self.warmup_method = warmup_method
|
|
||||||
super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
|
|
||||||
|
|
||||||
def get_lr(self):
|
|
||||||
warmup_factor = 1
|
|
||||||
if self.last_epoch < self.warmup_iters:
|
|
||||||
if self.warmup_method == "constant":
|
|
||||||
warmup_factor = self.warmup_factor
|
|
||||||
elif self.warmup_method == "linear":
|
|
||||||
alpha = self.last_epoch / self.warmup_iters
|
|
||||||
warmup_factor = self.warmup_factor * (1 - alpha) + alpha
|
|
||||||
return [
|
|
||||||
base_lr
|
|
||||||
* warmup_factor
|
|
||||||
* self.gamma ** bisect_right(self.milestones, self.last_epoch)
|
|
||||||
for base_lr in self.base_lrs
|
|
||||||
]
|
|
|
@ -1,6 +1,6 @@
|
||||||
# encoding: utf-8
|
# encoding: utf-8
|
||||||
"""
|
"""
|
||||||
@author: sherlock
|
@author: l1aoxingyu
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -14,10 +14,9 @@ from torch.backends import cudnn
|
||||||
|
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
from config import cfg
|
from config import cfg
|
||||||
from data import make_data_loader
|
from data import get_data_bunch
|
||||||
from engine.inference import inference
|
from engine.inference import inference
|
||||||
from modeling import build_model
|
from modeling import build_model
|
||||||
from utils.logger import setup_logger
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -54,11 +53,11 @@ def main():
|
||||||
|
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
|
train_databunch, test_databunch, num_query = get_data_bunch(cfg)
|
||||||
model = build_model(cfg, num_classes)
|
model = build_model(cfg, train_databunch.c)
|
||||||
model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
|
model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
|
||||||
|
|
||||||
inference(cfg, model, val_loader, num_query)
|
inference(cfg, model, test_databunch, num_query)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -5,43 +5,51 @@
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
|
from bisect import bisect_right
|
||||||
|
|
||||||
from torch.backends import cudnn
|
from torch.backends import cudnn
|
||||||
|
|
||||||
sys.path.append('.')
|
sys.path.append('.')
|
||||||
from config import cfg
|
from config import cfg
|
||||||
from data import make_data_loader
|
from data import get_data_bunch
|
||||||
from engine.trainer import do_train
|
from engine.trainer import do_train
|
||||||
from modeling import build_model
|
|
||||||
from layers import make_loss
|
from layers import make_loss
|
||||||
from solver import make_optimizer, WarmupMultiStepLR
|
from modeling import build_model
|
||||||
|
from utils.logger import Logger
|
||||||
from utils.logger import setup_logger
|
from fastai.vision import *
|
||||||
|
|
||||||
|
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
# prepare dataset
|
# prepare dataset
|
||||||
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
|
data_bunch, test_labels, num_query = get_data_bunch(cfg)
|
||||||
# prepare model
|
|
||||||
model = build_model(cfg, num_classes)
|
|
||||||
|
|
||||||
optimizer = make_optimizer(cfg, model)
|
# prepare model
|
||||||
scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
|
model = build_model(cfg, data_bunch.c)
|
||||||
cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
|
|
||||||
|
opt_func = partial(torch.optim.Adam)
|
||||||
|
|
||||||
|
def warmup_multistep(start: float, end: float, pct: float) -> float:
|
||||||
|
warmup_factor = 1
|
||||||
|
gamma = cfg.SOLVER.GAMMA
|
||||||
|
milestones = [1.0 * s / cfg.SOLVER.MAX_EPOCHS for s in cfg.SOLVER.STEPS]
|
||||||
|
warmup_iter = 1.0 * cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_EPOCHS
|
||||||
|
if pct < warmup_iter:
|
||||||
|
alpha = pct / warmup_iter
|
||||||
|
warmup_factor = cfg.SOLVER.WARMUP_FACTOR * (1 - alpha) + alpha
|
||||||
|
return start * warmup_factor * gamma ** bisect_right(milestones, pct)
|
||||||
|
|
||||||
|
lr_sched = Scheduler((cfg.SOLVER.BASE_LR, 0), cfg.SOLVER.MAX_EPOCHS, warmup_multistep)
|
||||||
|
|
||||||
loss_func = make_loss(cfg)
|
loss_func = make_loss(cfg)
|
||||||
|
|
||||||
arguments = {}
|
|
||||||
|
|
||||||
do_train(
|
do_train(
|
||||||
cfg,
|
cfg,
|
||||||
model,
|
model,
|
||||||
train_loader,
|
data_bunch,
|
||||||
val_loader,
|
test_labels,
|
||||||
optimizer,
|
opt_func,
|
||||||
scheduler,
|
lr_sched,
|
||||||
loss_func,
|
loss_func,
|
||||||
num_query
|
num_query
|
||||||
)
|
)
|
||||||
|
@ -64,20 +72,15 @@ def main():
|
||||||
cfg.merge_from_list(args.opts)
|
cfg.merge_from_list(args.opts)
|
||||||
cfg.freeze()
|
cfg.freeze()
|
||||||
|
|
||||||
output_dir = cfg.OUTPUT_DIR
|
sys.stdout = Logger(os.path.join(cfg.OUTPUT_DIR, 'log.txt'))
|
||||||
if output_dir and not os.path.exists(output_dir):
|
print(args)
|
||||||
os.makedirs(output_dir)
|
|
||||||
|
|
||||||
logger = setup_logger("reid_baseline", output_dir, 0)
|
|
||||||
logger.info("Using {} GPUS".format(num_gpus))
|
|
||||||
logger.info(args)
|
|
||||||
|
|
||||||
if args.config_file != "":
|
if args.config_file != "":
|
||||||
logger.info("Loaded configuration file {}".format(args.config_file))
|
print("Loaded configuration file {}".format(args.config_file))
|
||||||
with open(args.config_file, 'r') as cf:
|
with open(args.config_file, 'r') as cf:
|
||||||
config_str = "\n" + cf.read()
|
config_str = "\n" + cf.read()
|
||||||
logger.info(config_str)
|
print(config_str)
|
||||||
logger.info("Running with config:\n{}".format(cfg))
|
print("Running with config:\n{}".format(cfg))
|
||||||
|
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
train(cfg)
|
train(cfg)
|
||||||
|
|
|
@ -4,27 +4,48 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import errno
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(name, save_dir, distributed_rank):
|
def mkdir_if_missing(dir_path):
|
||||||
logger = logging.getLogger(name)
|
try:
|
||||||
logger.setLevel(logging.DEBUG)
|
os.makedirs(dir_path)
|
||||||
# don't log results for the non-master process
|
except OSError as e:
|
||||||
if distributed_rank > 0:
|
if e.errno != errno.EEXIST:
|
||||||
return logger
|
raise
|
||||||
ch = logging.StreamHandler(stream=sys.stdout)
|
|
||||||
ch.setLevel(logging.DEBUG)
|
|
||||||
formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
|
|
||||||
ch.setFormatter(formatter)
|
|
||||||
logger.addHandler(ch)
|
|
||||||
|
|
||||||
if save_dir:
|
|
||||||
fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
|
|
||||||
fh.setLevel(logging.DEBUG)
|
|
||||||
fh.setFormatter(formatter)
|
|
||||||
logger.addHandler(fh)
|
|
||||||
|
|
||||||
return logger
|
class Logger(object):
|
||||||
|
def __init__(self, fpath=None):
|
||||||
|
self.console = sys.stdout
|
||||||
|
self.file = None
|
||||||
|
if fpath is not None:
|
||||||
|
mkdir_if_missing(os.path.dirname(fpath))
|
||||||
|
self.file = open(fpath, 'w')
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def write(self, msg):
|
||||||
|
self.console.write(msg)
|
||||||
|
if self.file is not None:
|
||||||
|
self.file.write(msg)
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
self.console.flush()
|
||||||
|
if self.file is not None:
|
||||||
|
self.file.flush()
|
||||||
|
os.fsync(self.file.fileno())
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.console.close()
|
||||||
|
if self.file is not None:
|
||||||
|
self.file.close()
|
||||||
|
|
|
@ -1,48 +0,0 @@
|
||||||
# encoding: utf-8
|
|
||||||
"""
|
|
||||||
@author: liaoxingyu
|
|
||||||
@contact: sherlockliao01@gmail.com
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from ignite.metrics import Metric
|
|
||||||
|
|
||||||
from data.datasets.eval_reid import eval_func
|
|
||||||
|
|
||||||
|
|
||||||
class R1_mAP(Metric):
|
|
||||||
def __init__(self, num_query, max_rank=50):
|
|
||||||
super(R1_mAP, self).__init__()
|
|
||||||
self.num_query = num_query
|
|
||||||
self.max_rank = max_rank
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.feats = []
|
|
||||||
self.pids = []
|
|
||||||
self.camids = []
|
|
||||||
|
|
||||||
def update(self, output):
|
|
||||||
feat, pid, camid = output
|
|
||||||
self.feats.append(feat)
|
|
||||||
self.pids.extend(np.asarray(pid))
|
|
||||||
self.camids.extend(np.asarray(camid))
|
|
||||||
|
|
||||||
def compute(self):
|
|
||||||
feats = torch.cat(self.feats, dim=0)
|
|
||||||
# query
|
|
||||||
qf = feats[:self.num_query]
|
|
||||||
q_pids = np.asarray(self.pids[:self.num_query])
|
|
||||||
q_camids = np.asarray(self.camids[:self.num_query])
|
|
||||||
# gallery
|
|
||||||
gf = feats[self.num_query:]
|
|
||||||
g_pids = np.asarray(self.pids[self.num_query:])
|
|
||||||
g_camids = np.asarray(self.camids[self.num_query:])
|
|
||||||
m, n = qf.shape[0], gf.shape[0]
|
|
||||||
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
|
|
||||||
torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
|
||||||
distmat.addmm_(1, -2, qf, gf.t())
|
|
||||||
distmat = distmat.cpu().numpy()
|
|
||||||
cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)
|
|
||||||
|
|
||||||
return cmc, mAP
|
|
Loading…
Reference in New Issue