mirror of https://github.com/JDAI-CV/fast-reid.git

updating for pytorch1.6

parent 9844760d7f
commit ac8409a7da
@@ -496,11 +496,13 @@ class DefaultTrainer(SimpleTrainer):
             cfg.SOLVER.STEPS[i] *= iters_per_epoch
         cfg.SOLVER.SWA.ITER *= iters_per_epoch
         cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
-        cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
 
+        ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
         # Evaluation period must be divided by 200 for writing into tensorboard.
-        num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
-        cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mod
+        eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
+        cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
+        # Change checkpoint saving period consistent with evaluation period.
+        cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)
 
         logger = logging.getLogger(__name__)
         logger.info(
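The arithmetic above pads the per-iteration evaluation period up to the next multiple of 200 and then rescales the checkpoint period by the original checkpoint/eval ratio. A small worked example with made-up values (not fast-reid defaults):

    # Hypothetical values, only to illustrate the rounding above.
    eval_period_epochs = 30
    iters_per_epoch = 150
    ckpt_multiple = 2                                    # CHECKPOINT_PERIOD / EVAL_PERIOD

    eval_iters = eval_period_epochs * iters_per_epoch    # 4500 iterations
    eval_iters += (200 - eval_iters) % 200               # 4600, the next multiple of 200
    checkpoint_iters = int(eval_iters * ckpt_multiple)   # 9200, stays aligned with eval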
@@ -210,11 +210,6 @@ class SimpleTrainer(TrainerBase):
         loss_dict = self.model.losses(outputs, targets)
 
         losses = sum(loss_dict.values())
-        self._detect_anomaly(losses, loss_dict)
-
-        metrics_dict = loss_dict
-        metrics_dict["data_time"] = data_time
-        self._write_metrics(metrics_dict)
 
         """
         If you need accumulate gradients or something similar, you can
@@ -223,6 +218,12 @@ class SimpleTrainer(TrainerBase):
         self.optimizer.zero_grad()
         losses.backward()
 
+        with torch.cuda.stream(torch.cuda.Stream()):
+            metrics_dict = loss_dict
+            metrics_dict["data_time"] = data_time
+            self._write_metrics(metrics_dict)
+            self._detect_anomaly(losses, loss_dict)
+
         """
         If you need gradient clipping/scaling or other processing, you can
         wrap the optimizer with your custom `step()` method.
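Moving the metric bookkeeping into a freshly created side stream presumably keeps it off the default CUDA stream while the optimizer step proceeds. A minimal sketch of the same pattern, with hypothetical helper names:

    import torch

    def write_metrics_on_side_stream(loss_dict, data_time, write_fn):
        # Sketch only: run logging work on a side stream when CUDA is available.
        if torch.cuda.is_available():
            with torch.cuda.stream(torch.cuda.Stream()):
                metrics = {k: v.detach() for k, v in loss_dict.items()}
                metrics["data_time"] = data_time
                write_fn(metrics)
        else:
            write_fn(dict(loss_dict, data_time=data_time))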
@@ -24,8 +24,8 @@ class BatchNorm(nn.BatchNorm2d):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
-        if weight_init is not None: self.weight.data.fill_(weight_init)
-        if bias_init is not None: self.bias.data.fill_(bias_init)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
         self.weight.requires_grad_(not weight_freeze)
         self.bias.requires_grad_(not bias_freeze)
 
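nn.init.constant_ is the in-place initializer that replaces the direct .data.fill_ calls, and requires_grad_(False) is what actually freezes the affine parameters. A standalone sketch (not the repo's get_norm factory) of a frozen BatchNorm built the same way:

    import torch.nn as nn

    def frozen_batchnorm(num_features, weight_init=1.0, bias_init=0.0):
        # Sketch: initialize the affine parameters, then freeze them.
        bn = nn.BatchNorm2d(num_features)
        nn.init.constant_(bn.weight, weight_init)
        nn.init.constant_(bn.bias, bias_init)
        bn.weight.requires_grad_(False)
        bn.bias.requires_grad_(False)
        return bn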
@@ -57,13 +57,14 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
 class AdaptiveAvgMaxPool2d(nn.Module):
     def __init__(self):
         super(AdaptiveAvgMaxPool2d, self).__init__()
-        self.avgpool = FastGlobalAvgPool2d()
+        self.gap = FastGlobalAvgPool2d()
+        self.gmp = nn.AdaptiveMaxPool2d(1)
 
     def forward(self, x):
-        x_avg = self.avgpool(x, self.output_size)
-        x_max = F.adaptive_max_pool2d(x, 1)
-        x = x_max + x_avg
-        return x
+        avg_feat = self.gap(x)
+        max_feat = self.gmp(x)
+        feat = avg_feat + max_feat
+        return feat
 
 
 class FastGlobalAvgPool2d(nn.Module):
@@ -304,10 +304,27 @@ def build_resnet_backbone(cfg):
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH
 
-    num_blocks_per_stage = {'18x': [2, 2, 2, 2], '34x': [3, 4, 6, 3], '50x': [3, 4, 6, 3],
-                            '101x': [3, 4, 23, 3],}[depth]
-    nl_layers_per_stage = {'18x': [0, 0, 0, 0], '34x': [0, 0, 0, 0], '50x': [0, 2, 3, 0], '101x': [0, 2, 9, 0]}[depth]
-    block = {'18x': BasicBlock, '34x': BasicBlock, '50x': Bottleneck, '101x': Bottleneck}[depth]
+    num_blocks_per_stage = {
+        '18x': [2, 2, 2, 2],
+        '34x': [3, 4, 6, 3],
+        '50x': [3, 4, 6, 3],
+        '101x': [3, 4, 23, 3],
+    }[depth]
+
+    nl_layers_per_stage = {
+        '18x': [0, 0, 0, 0],
+        '34x': [0, 0, 0, 0],
+        '50x': [0, 2, 3, 0],
+        '101x': [0, 2, 9, 0]
+    }[depth]
+
+    block = {
+        '18x': BasicBlock,
+        '34x': BasicBlock,
+        '50x': Bottleneck,
+        '101x': Bottleneck
+    }[depth]
 
     model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
@@ -36,6 +36,7 @@ class BNneckHead(nn.Module):
         See :class:`ReIDHeads.forward`.
         """
         global_feat = self.pool_layer(features)
+        global_feat = torch.clamp(global_feat, min=0., max=1.)
         bn_feat = self.bnneck(global_feat)
         bn_feat = bn_feat[..., 0, 0]
 
@@ -43,8 +44,10 @@ class BNneckHead(nn.Module):
         if not self.training: return bn_feat
 
         # Training
-        try: cls_outputs = self.classifier(bn_feat)
-        except TypeError: cls_outputs = self.classifier(bn_feat, targets)
+        if self.classifier.__class__.__name__ == 'Linear':
+            cls_outputs = self.classifier(bn_feat)
+        else:
+            cls_outputs = self.classifier(bn_feat, targets)
 
         pred_class_logits = F.linear(bn_feat, self.classifier.weight)
 
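The explicit class-name check replaces the earlier try/except TypeError dispatch: a plain nn.Linear classifier only takes the features, while the margin-based classifiers in this codebase also take the ground-truth targets during training. A minimal sketch of the same dispatch, using a hypothetical margin classifier:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class HypotheticalMarginClassifier(nn.Module):
        # Stand-in for a margin-based head that needs targets at training time.
        def __init__(self, in_feat, num_classes):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(num_classes, in_feat))

        def forward(self, features, targets):
            return F.linear(features, self.weight)  # margin adjustment omitted

    def run_classifier(classifier, feat, targets):
        if classifier.__class__.__name__ == 'Linear':
            return classifier(feat)           # nn.Linear: features only
        return classifier(feat, targets)      # margin heads: features + targets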
@@ -38,8 +38,11 @@ class LinearHead(nn.Module):
         if not self.training: return global_feat
 
         # Training
-        try: cls_outputs = self.classifier(global_feat)
-        except TypeError: cls_outputs = self.classifier(global_feat, targets)
+        if self.classifier.__class__.__name__ == 'Linear':
+            cls_outputs = self.classifier(global_feat)
+        else:
+            cls_outputs = self.classifier(global_feat, targets)
+
 
         pred_class_logits = F.linear(global_feat, self.classifier.weight)
 
@@ -21,14 +21,10 @@ class ReductionHead(nn.Module):
 
         self.bottleneck = nn.Sequential(
             nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
             get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT),
             nn.LeakyReLU(0.1, inplace=True),
+            get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True),
         )
-
-        self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
-
         self.bottleneck.apply(weights_init_kaiming)
-        self.bnneck.apply(weights_init_kaiming)
-
         # identity classification layer
         cls_type = cfg.MODEL.HEADS.CLS_LAYER
@@ -46,17 +42,19 @@ class ReductionHead(nn.Module):
         """
         See :class:`ReIDHeads.forward`.
         """
-        features = self.pool_layer(features)
-        global_feat = self.bottleneck(features)
-        bn_feat = self.bnneck(global_feat)
+        global_feat = self.pool_layer(features)
+        bn_feat = self.bottleneck(global_feat)
         bn_feat = bn_feat[..., 0, 0]
 
         # Evaluation
         if not self.training: return bn_feat
 
-        # Training
-        try: cls_outputs = self.classifier(bn_feat)
-        except TypeError: cls_outputs = self.classifier(bn_feat, targets)
+        # Training
+        if self.classifier.__class__.__name__ == 'Linear':
+            cls_outputs = self.classifier(bn_feat)
+        else:
+            cls_outputs = self.classifier(bn_feat, targets)
 
         pred_class_logits = F.linear(bn_feat, self.classifier.weight)
 
@@ -66,4 +64,3 @@ class ReductionHead(nn.Module):
             raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
-
         return cls_outputs, pred_class_logits, feat
 
@@ -29,21 +29,15 @@ class CircleLoss(object):
             all_embedding = embedding
             all_targets = targets
 
-        dist_mat = torch.matmul(embedding, all_embedding.t())
+        dist_mat = torch.matmul(all_embedding, all_embedding.t())
 
-        N, M = dist_mat.size()
-        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()
+        N = dist_mat.size(0)
+        is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float()
 
         # Compute the mask which ignores the relevance score of the query to itself
-        if M > N:
-            identity_indx = torch.eye(N, N, device=is_pos.device)
-            remain_indx = torch.zeros(N, M - N, device=is_pos.device)
-            identity_indx = torch.cat((identity_indx, remain_indx), dim=1)
-            is_pos = is_pos - identity_indx
-        else:
-            is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
+        is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
 
-        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+        is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
 
         s_p = dist_mat * is_pos
         s_n = dist_mat * is_neg
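With the gather branch gone, the similarity matrix is always square (N x N), so removing each query's match with itself only needs the N x N identity subtracted from the positive mask. A tiny example with toy labels:

    import torch

    targets = torch.tensor([0, 0, 1, 1])
    N = targets.size(0)

    same = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t())
    is_pos = same.float() - torch.eye(N)   # positives, diagonal (self-match) removed
    is_neg = (~same).float()               # negatives

    # is_pos[0] -> tensor([0., 1., 0., 0.]): sample 0 pairs only with sample 1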
@@ -48,7 +48,7 @@ class CrossEntropyLoss(object):
         if self._eps >= 0:
             smooth_param = self._eps
         else:
-            # adaptive lsr
+            # Adaptive label smooth regularization
             soft_label = F.softmax(pred_class_logits, dim=1)
             smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
 
@@ -60,8 +60,16 @@ class CrossEntropyLoss(object):
 
         loss = (-targets * log_probs).sum(dim=1)
 
+        """
+        # confidence penalty
+        conf_penalty = 0.3
+        probs = F.softmax(pred_class_logits, dim=1)
+        entropy = torch.sum(-probs * log_probs, dim=1)
+        loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.)
+        """
+
         with torch.no_grad():
-            non_zero_cnt = max(loss.nonzero().size(0), 1)
+            non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
 
         loss = loss.sum() / non_zero_cnt
 
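Calling Tensor.nonzero() with no arguments started emitting a deprecation warning around PyTorch 1.5; as_tuple=False keeps the old return type (an (n, d) index tensor) and silences the warning. Counting the non-zero losses can also be done with a plain comparison, e.g.:

    import torch

    loss = torch.tensor([0.0, 0.3, 0.0, 1.2])

    n1 = loss.nonzero(as_tuple=False).size(0)  # 2, explicit non-deprecated form
    n2 = int((loss != 0).sum())                # 2, equivalent count without nonzero()
    non_zero_cnt = max(n1, 1)                  # guard against an all-zero batch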
@@ -110,11 +110,11 @@ class TripletLoss(object):
             all_embedding = embedding
             all_targets = targets
 
-        dist_mat = euclidean_dist(embedding, all_embedding)
+        dist_mat = euclidean_dist(all_embedding, all_embedding)
 
-        N, M = dist_mat.size()
-        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
-        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+        N = dist_mat.size(0)
+        is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
+        is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
 
         if self._hard_mining:
             dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
@@ -7,7 +7,6 @@
 import torch
 
-
 @torch.no_grad()
 def concat_all_gather(tensor):
     """
     Performs all_gather operation on the provided tensors.
@@ -1,19 +1,13 @@
 # encoding: utf-8
 """
 @author: xingyu liao
 @contact: sherlockliao01@gmail.com
 """
 
-import torch
 import math
+import torch
 from torch.optim.optimizer import Optimizer
 
 
 class Adam(Optimizer):
     r"""Implements Adam algorithm.
 
     It has been proposed in `Adam: A Method for Stochastic Optimization`_.
 
     The implementation of the L2 penalty follows changes proposed in
     `Decoupled Weight Decay Regularization`_.
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
@@ -26,9 +20,10 @@ class Adam(Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False)
 
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
+    .. _Decoupled Weight Decay Regularization:
+        https://arxiv.org/abs/1711.05101
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
     """
@@ -43,6 +38,8 @@ class Adam(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         defaults = dict(lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
         super(Adam, self).__init__(params, defaults)
@@ -52,22 +49,23 @@ class Adam(Optimizer):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None or group['freeze']:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                 amsgrad = group['amsgrad']
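Decorating step() with @torch.no_grad() and re-enabling autograd only around the closure is the structure the stock PyTorch 1.6 optimizers use; it also makes the old p.grad.data dereferences unnecessary. A condensed skeleton (a sketch, not this repo's Adam):

    import torch
    from torch.optim.optimizer import Optimizer

    class SketchOptimizer(Optimizer):
        @torch.no_grad()
        def step(self, closure=None):
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()       # the closure may need to build a graph
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.add_(p.grad, alpha=-group['lr'])  # trivial SGD-style update
            return loss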
@@ -78,12 +76,12 @@ class Adam(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 if amsgrad:
@@ -91,25 +89,25 @@ class Adam(Optimizer):
                 beta1, beta2 = group['betas']
 
                 state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
 
                 if group['weight_decay'] != 0:
-                    grad.add_(group['weight_decay'], p.data)
+                    grad = grad.add(p, alpha=group['weight_decay'])
 
                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 if amsgrad:
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                     # Use the max. for normalizing running avg. of gradient
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                step_size = group['lr'] / bias_correction1
 
-                p.data.addcdiv_(-step_size, exp_avg, denom)
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
-        return loss
+        return loss
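The remaining edits are the PyTorch 1.6 spellings of the fused in-place ops: the scalar moves into a keyword argument (alpha= / value=) instead of the deprecated leading positional, updates go through p rather than p.data, and math.sqrt(bias_correction2) is folded into the denominator instead of the step size, matching the upstream torch.optim.Adam formulation (the two differ only in how eps interacts with the bias correction). A side-by-side sketch of the argument change:

    import torch

    a = torch.ones(3)
    b = torch.full((3,), 2.0)
    c = torch.full((3,), 4.0)

    # Deprecated (warns on PyTorch >= 1.5):  a.add_(0.5, b)
    a.add_(b, alpha=0.5)         # a += 0.5 * b

    # Deprecated:  a.addcmul_(0.1, b, c)
    a.addcmul_(b, c, value=0.1)  # a += 0.1 * (b * c)

    # Deprecated:  a.addcdiv_(0.1, b, c)
    a.addcdiv_(b, c, value=0.1)  # a += 0.1 * (b / c)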
@@ -95,21 +95,21 @@ class SGD(Optimizer):
             for p in group['params']:
                 if p.grad is None or group['freeze']:
                     continue
-                d_p = p.grad.data
+                d_p = p.grad
                 if weight_decay != 0:
-                    d_p.add_(weight_decay, p.data)
+                    d_p.add_(p, alpha=weight_decay)
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
                         buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                     else:
                         buf = param_state['momentum_buffer']
-                        buf.mul_(momentum).add_(1 - dampening, d_p)
+                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                     if nesterov:
-                        d_p = d_p.add(momentum, buf)
+                        d_p = d_p.add(buf, alpha=momentum)
                     else:
                         d_p = buf
 
-                p.data.add_(-group['lr'], d_p)
+                p.data.add_(d_p, alpha=-group['lr'])
 
-        return loss
+        return loss