mirror of https://github.com/JDAI-CV/fast-reid.git

Commit 55300730e1 (parent 44cee30dfc): update fastreid v1.2 readme and changelog

CHANGELOG.md (+14)
@@ -1,5 +1,19 @@
 # Changelog
 
+### v1.2 (06/04/2021)
+
+#### New Features
+
+- Multiple machine training support
+- [RepVGG](https://github.com/DingXiaoH/RepVGG) backbone
+- [Partial FC](projects/FastFace)
+
+#### Improvements
+
+- Torch2trt pipeline
+- Decouple linear transforms and softmax
+- Config decorator
+
 ### v1.1 (29/01/2021)
 
 #### New Features
@@ -8,6 +8,7 @@ FastReID is a research platform that implements state-of-the-art re-identification algorithms.
 
 ## What's New
 
+- [Apr 2021] Partial FC supported in [FastFace](projects/FastFace)!
 - [Jan 2021] TRT network definition APIs in [FastRT](projects/FastRT) have been released!
   Thanks to [Darren](https://github.com/TCHeish) for the contribution.
 - [Jan 2021] NAIC20 (ReID track) [1st solution](projects/NAIC20) based on FastReID has been released!
@@ -16,8 +16,8 @@ SOLVER:
 
   IMS_PER_BATCH: 64
   MAX_EPOCH: 60
-  WARMUP_ITERS: 2000
-  FREEZE_ITERS: 1000
+  WARMUP_ITERS: 3000
+  FREEZE_ITERS: 3000
 
   CHECKPOINT_PERIOD: 10
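For readers skimming the diff: WARMUP_ITERS controls how long the learning rate ramps up from near zero to its base value, and FREEZE_ITERS controls how long the pretrained backbone stays frozen at the start of training. A minimal sketch of the linear-warmup multiplier, assuming the common linear schedule (illustrative, not fastreid's exact implementation):

```python
def warmup_lr_factor(cur_iter: int, warmup_iters: int = 3000) -> float:
    """Multiplier applied to the base LR during warmup (linear ramp)."""
    if cur_iter >= warmup_iters:
        return 1.0
    return max(cur_iter / warmup_iters, 1e-3)  # floor keeps the LR nonzero at iter 0

# e.g. warmup_lr_factor(1500) == 0.5, warmup_lr_factor(3000) == 1.0
```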
@@ -25,6 +25,9 @@ DATASETS:
   NAMES: ("VeRi",)
   TESTS: ("VeRi",)
 
+DATALOADER:
+  SAMPLER_TRAIN: BalancedIdentitySampler
+
 TEST:
   EVAL_PERIOD: 10
   IMS_PER_BATCH: 256
@@ -5,4 +5,4 @@
 """
 
 
-__version__ = "1.0.0"
+__version__ = "1.3"
@@ -206,6 +206,9 @@ _C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
 _C.DATALOADER.NUM_INSTANCE = 4
 _C.DATALOADER.NUM_WORKERS = 8
 
+# Per-set sampling weights, consumed by SetReWeightSampler
+_C.DATALOADER.SET_WEIGHT = []
+
 # ---------------------------------------------------------------------------- #
 # Solver
 # ---------------------------------------------------------------------------- #
@@ -53,6 +53,9 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=None):
         sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
     elif sampler_name == "BalancedIdentitySampler":
         sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
+    elif sampler_name == "SetReWeightSampler":
+        set_weight = cfg.DATALOADER.SET_WEIGHT
+        sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
     else:
         raise ValueError("Unknown training sampler: {}".format(sampler_name))
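Putting the two config additions together, a hedged sketch of how a user would opt into the new sampler (assumes fastreid's standard `get_cfg` entry point; the weight values are made up):

```python
# Illustrative only: route training through the new sampler via the config keys
# added in this commit. SET_WEIGHT holds one weight per set (here, per camera).
from fastreid.config import get_cfg

cfg = get_cfg()
cfg.DATALOADER.SAMPLER_TRAIN = "SetReWeightSampler"
cfg.DATALOADER.NUM_INSTANCE = 4
cfg.DATALOADER.SET_WEIGHT = [2, 1, 1]  # made-up weights for three camera sets
```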
@@ -4,12 +4,13 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler
+from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
 from .data_sampler import TrainingSampler, InferenceSampler
 
 __all__ = [
     "BalancedIdentitySampler",
     "NaiveIdentitySampler",
+    "SetReWeightSampler",
     "TrainingSampler",
     "InferenceSampler"
 ]
@@ -119,6 +119,82 @@ class BalancedIdentitySampler(Sampler):
             batch_indices = []
 
 
+class SetReWeightSampler(Sampler):
+    def __init__(self, data_source: list, mini_batch_size: int, num_instances: int, set_weight: list,
+                 seed: Optional[int] = None):
+        self.data_source = data_source
+        self.num_instances = num_instances
+        self.num_pids_per_batch = mini_batch_size // self.num_instances
+
+        self.set_weight = set_weight
+
+        self._rank = comm.get_rank()
+        self._world_size = comm.get_world_size()
+        self.batch_size = mini_batch_size * self._world_size
+
+        assert self.batch_size % (sum(self.set_weight) * self.num_instances) == 0 and \
+               self.batch_size > sum(self.set_weight) * self.num_instances, \
+            "Batch size must be a multiple of (and larger than) sum(set_weight) * num_instances"
+
+        self.index_pid = dict()
+        self.pid_cam = defaultdict(list)
+        self.pid_index = defaultdict(list)
+
+        self.cam_pid = defaultdict(list)
+
+        for index, info in enumerate(data_source):
+            pid = info[1]
+            camid = info[2]
+            self.index_pid[index] = pid
+            self.pid_cam[pid].append(camid)
+            self.pid_index[pid].append(index)
+            self.cam_pid[camid].append(pid)
+
+        # Sampling probability for each pid entry within a camera set,
+        # proportional to how many images that pid has.
+        self.set_pid_prob = defaultdict(list)
+        for camid, pid_list in self.cam_pid.items():
+            index_per_pid = []
+            for pid in pid_list:
+                index_per_pid.append(len(self.pid_index[pid]))
+            cam_image_number = sum(index_per_pid)
+            prob = [i / cam_image_number for i in index_per_pid]
+            self.set_pid_prob[camid] = prob
+
+        self.pids = sorted(list(self.pid_index.keys()))
+        self.num_identities = len(self.pids)
+
+        if seed is None:
+            seed = comm.shared_random_seed()
+        self._seed = int(seed)
+
+    def __iter__(self):
+        start = self._rank
+        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
+
+    def _infinite_indices(self):
+        np.random.seed(self._seed)
+        batch_indices = []
+        while True:
+            # Each pass over the camera sets contributes
+            # sum(set_weight) * num_instances indices, so a full cross-GPU
+            # batch may take several passes to assemble.
+            for camid in range(len(self.cam_pid)):
+                select_pids = np.random.choice(self.cam_pid[camid], size=self.set_weight[camid], replace=False,
+                                               p=self.set_pid_prob[camid])
+                for pid in select_pids:
+                    index_list = self.pid_index[pid]
+                    if len(index_list) > self.num_instances:
+                        select_indices = np.random.choice(index_list, size=self.num_instances, replace=False)
+                    else:
+                        select_indices = np.random.choice(index_list, size=self.num_instances, replace=True)
+
+                    batch_indices += list(select_indices)
+            np.random.shuffle(batch_indices)
+
+            if len(batch_indices) == self.batch_size:
+                yield from reorder_index(batch_indices, self._world_size)
+                batch_indices = []
+
+
 class NaiveIdentitySampler(Sampler):
     """
     Randomly sample N identities, then for each identity,
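A hedged smoke test for the new sampler (the `img_items` tuples mimic fastreid's `(path, pid, camid)` layout; all values are made up):

```python
# Illustrative only: exercise SetReWeightSampler on a toy dataset with
# 4 identities spread over 2 camera sets, assuming the class above is
# importable from fastreid.data.samplers.
from fastreid.data.samplers import SetReWeightSampler

img_items = [(f"img_{i}.jpg", i % 4, i % 2) for i in range(32)]
sampler = SetReWeightSampler(img_items, mini_batch_size=8, num_instances=2,
                             set_weight=[1, 1], seed=0)

it = iter(sampler)
batch = [next(it) for _ in range(8)]  # 8 dataset indices forming one batch
```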
@@ -6,4 +6,5 @@
 
 from .lamb import Lamb
 from .swa import SWA
+from .radam import RAdam
 from torch.optim import *
@@ -0,0 +1,149 @@
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class RAdam(Optimizer):
+    """Rectified Adam (Liu et al., "On the Variance of the Adaptive
+    Learning Rate and Beyond"), with the rectification term cached per step."""
+
+    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        # Cache (step, N_sma, step_size) for the last 10 steps, since the
+        # rectification term depends only on the step count.
+        self.buffer = [[None, None, None] for _ in range(10)]
+        super(RAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(RAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                # First and second moment running averages.
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+                state['step'] += 1
+                buffered = self.buffer[int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    # Length of the approximated simple moving average (rho_t).
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = group['lr'] * math.sqrt(
+                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = group['lr'] / (1 - beta1 ** state['step'])
+                    buffered[2] = step_size
+
+                if group['weight_decay'] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+                else:
+                    # Fall back to an un-rectified, SGD-with-momentum-style update.
+                    p_data_fp32.add_(exp_avg, alpha=-step_size)
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
+
+
+class PlainRAdam(Optimizer):
+    """RAdam without the step-size cache (rectification recomputed every step)."""
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+
+        super(PlainRAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(PlainRAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+                state['step'] += 1
+                beta2_t = beta2 ** state['step']
+                N_sma_max = 2 / (1 - beta2) - 1
+                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+
+                if group['weight_decay'] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    step_size = group['lr'] * math.sqrt(
+                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+                else:
+                    step_size = group['lr'] / (1 - beta1 ** state['step'])
+                    p_data_fp32.add_(exp_avg, alpha=-step_size)
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
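A hedged usage sketch for the new optimizer (the toy model and data are made up; the import path is assumed from the `from .radam import RAdam` line added above):

```python
# Illustrative only: a few RAdam steps on a toy regression problem.
import torch
from torch import nn
from fastreid.solver.optim import RAdam  # assumed module path

model = nn.Linear(10, 1)
opt = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(5):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
```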