updating for pytorch1.6

pull/240/head
liaoxingyu 2020-08-20 15:51:41 +08:00
parent 9844760d7f
commit ac8409a7da
14 changed files with 107 additions and 84 deletions

View File

@@ -496,11 +496,13 @@ class DefaultTrainer(SimpleTrainer):
cfg.SOLVER.STEPS[i] *= iters_per_epoch
cfg.SOLVER.SWA.ITER *= iters_per_epoch
cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
# The evaluation period must be divisible by 200 so that results are written to TensorBoard.
num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mod
eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
# Keep the checkpoint saving period consistent with the evaluation period.
cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)
logger = logging.getLogger(__name__)
logger.info(
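Note on this hunk: the epoch-based solver periods are converted into iterations, the evaluation period is rounded up to the next multiple of 200 so it lands on a TensorBoard write step, and the checkpoint period is rescaled to keep its original ratio to the evaluation period. A minimal sketch of that arithmetic, with purely illustrative names and numbers:

```python
def rescale_periods(eval_period_epochs, ckpt_period_epochs, iters_per_epoch, writer_period=200):
    # Preserve the checkpoint/eval ratio before converting epochs to iterations.
    ckpt_multiple = ckpt_period_epochs / eval_period_epochs
    eval_iters = eval_period_epochs * iters_per_epoch
    # Round up to the next multiple of the TensorBoard writer period (200 iters).
    eval_iters += (writer_period - eval_iters) % writer_period
    ckpt_iters = int(eval_iters * ckpt_multiple)
    return eval_iters, ckpt_iters

# e.g. eval every 30 epochs, checkpoint every 60 epochs, 117 iters/epoch
# -> (3600, 7200): 30 * 117 = 3510 is rounded up to 3600 and the 2x ratio is kept.
print(rescale_periods(30, 60, 117))
```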

View File

@@ -210,11 +210,6 @@ class SimpleTrainer(TrainerBase):
loss_dict = self.model.losses(outputs, targets)
losses = sum(loss_dict.values())
self._detect_anomaly(losses, loss_dict)
metrics_dict = loss_dict
metrics_dict["data_time"] = data_time
self._write_metrics(metrics_dict)
"""
If you need to accumulate gradients or something similar, you can
@@ -223,6 +218,12 @@ class SimpleTrainer(TrainerBase):
self.optimizer.zero_grad()
losses.backward()
with torch.cuda.stream(torch.cuda.Stream()):
metrics_dict = loss_dict
metrics_dict["data_time"] = data_time
self._write_metrics(metrics_dict)
self._detect_anomaly(losses, loss_dict)
"""
If you need gradient clipping/scaling or other processing, you can
wrap the optimizer with your custom `step()` method.
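Note on this hunk: metric recording is moved after `backward()` and wrapped in a freshly created CUDA stream, so the GPU-to-CPU copies needed to read the loss values do not serialize with work queued on the default stream. A rough, self-contained sketch of that pattern; the `write` callback is a hypothetical stand-in for `_write_metrics`:

```python
import torch

def log_losses_async(loss_dict, data_time, write):
    # Enqueue the device-to-host copies on a side stream so they do not
    # block backward()/optimizer work on the default stream.
    if torch.cuda.is_available():
        with torch.cuda.stream(torch.cuda.Stream()):
            metrics = {k: v.detach().cpu() for k, v in loss_dict.items()}
    else:
        metrics = {k: v.detach() for k, v in loss_dict.items()}
    metrics["data_time"] = data_time
    write(metrics)  # hypothetical callback, e.g. pushing into an event storage
```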

View File

@@ -24,8 +24,8 @@ class BatchNorm(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
bias_init=0.0):
super().__init__(num_features, eps=eps, momentum=momentum)
if weight_init is not None: self.weight.data.fill_(weight_init)
if bias_init is not None: self.bias.data.fill_(bias_init)
if weight_init is not None: nn.init.constant_(self.weight, weight_init)
if bias_init is not None: nn.init.constant_(self.bias, bias_init)
self.weight.requires_grad_(not weight_freeze)
self.bias.requires_grad_(not bias_freeze)
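Note on this hunk: the direct `.data.fill_()` writes are replaced with `nn.init.constant_`, which performs the same fill under `torch.no_grad()` without touching `.data`; freezing is still done by toggling `requires_grad`. A small usage sketch with assumed default values:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(64)
nn.init.constant_(bn.weight, 1.0)   # same effect as weight.data.fill_(1.0), without .data access
nn.init.constant_(bn.bias, 0.0)
bn.weight.requires_grad_(False)     # what weight_freeze=True would do
bn.bias.requires_grad_(False)       # what bias_freeze=True would do
```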

View File

@@ -57,13 +57,14 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.avgpool = FastGlobalAvgPool2d()
self.gap = FastGlobalAvgPool2d()
self.gmp = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
x_avg = self.avgpool(x, self.output_size)
x_max = F.adaptive_max_pool2d(x, 1)
x = x_max + x_avg
return x
avg_feat = self.gap(x)
max_feat = self.gmp(x)
feat = avg_feat + max_feat
return feat
class FastGlobalAvgPool2d(nn.Module):
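Note on this hunk: besides the renaming, the rewrite removes the old call to `self.avgpool(x, self.output_size)`, where `self.output_size` is not defined anywhere in the `__init__` shown. A self-contained sketch of the new behaviour, with `nn.AdaptiveAvgPool2d(1)` standing in for fastreid's `FastGlobalAvgPool2d`:

```python
import torch
import torch.nn as nn

class AdaptiveAvgMaxPool2d(nn.Module):
    """Sum of global average pooling and global max pooling."""
    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)  # stand-in for FastGlobalAvgPool2d
        self.gmp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        return self.gap(x) + self.gmp(x)

feat = AdaptiveAvgMaxPool2d()(torch.randn(8, 2048, 16, 8))  # -> shape (8, 2048, 1, 1)
```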

View File

@@ -304,10 +304,27 @@ def build_resnet_backbone(cfg):
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
num_blocks_per_stage = {'18x': [2, 2, 2, 2], '34x': [3, 4, 6, 3], '50x': [3, 4, 6, 3],
'101x': [3, 4, 23, 3],}[depth]
nl_layers_per_stage = {'18x': [0, 0, 0, 0], '34x': [0, 0, 0, 0], '50x': [0, 2, 3, 0], '101x': [0, 2, 9, 0]}[depth]
block = {'18x': BasicBlock, '34x': BasicBlock, '50x': Bottleneck, '101x': Bottleneck}[depth]
num_blocks_per_stage = {
'18x': [2, 2, 2, 2],
'34x': [3, 4, 6, 3],
'50x': [3, 4, 6, 3],
'101x': [3, 4, 23, 3],
}[depth]
nl_layers_per_stage = {
'18x': [0, 0, 0, 0],
'34x': [0, 0, 0, 0],
'50x': [0, 2, 3, 0],
'101x': [0, 2, 9, 0]
}[depth]
block = {
'18x': BasicBlock,
'34x': BasicBlock,
'50x': Bottleneck,
'101x': Bottleneck
}[depth]
model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
num_blocks_per_stage, nl_layers_per_stage)
if pretrain:

View File

@@ -36,6 +36,7 @@ class BNneckHead(nn.Module):
See :class:`ReIDHeads.forward`.
"""
global_feat = self.pool_layer(features)
global_feat = torch.clamp(global_feat, min=0., max=1.)
bn_feat = self.bnneck(global_feat)
bn_feat = bn_feat[..., 0, 0]
@@ -43,8 +44,10 @@ class BNneckHead(nn.Module):
if not self.training: return bn_feat
# Training
try: cls_outputs = self.classifier(bn_feat)
except TypeError: cls_outputs = self.classifier(bn_feat, targets)
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(bn_feat)
else:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
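Note on this hunk: replacing the `try/except TypeError` with an explicit class-name check makes the dispatch intent visible: a plain `nn.Linear` head only takes the features, while margin-based heads also take the targets so they can apply their margin during training. A hedged sketch of that dispatch (the helper name is illustrative, not fastreid's API):

```python
import torch.nn.functional as F

def run_classifier(classifier, bn_feat, targets):
    # nn.Linear ignores labels; margin-based heads (e.g. arcface/circle-style
    # classifiers) need the ground-truth targets to apply their margin.
    if classifier.__class__.__name__ == 'Linear':
        cls_outputs = classifier(bn_feat)
    else:
        cls_outputs = classifier(bn_feat, targets)
    # Margin-free logits, useful for monitoring classification accuracy.
    pred_class_logits = F.linear(bn_feat, classifier.weight)
    return cls_outputs, pred_class_logits
```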

View File

@@ -38,8 +38,11 @@ class LinearHead(nn.Module):
if not self.training: return global_feat
# Training
try: cls_outputs = self.classifier(global_feat)
except TypeError: cls_outputs = self.classifier(global_feat, targets)
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(global_feat)
else:
cls_outputs = self.classifier(global_feat, targets)
pred_class_logits = F.linear(global_feat, self.classifier.weight)

View File

@@ -21,14 +21,10 @@ class ReductionHead(nn.Module):
self.bottleneck = nn.Sequential(
nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT),
nn.LeakyReLU(0.1, inplace=True),
get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True),
)
self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
self.bottleneck.apply(weights_init_kaiming)
self.bnneck.apply(weights_init_kaiming)
# identity classification layer
cls_type = cfg.MODEL.HEADS.CLS_LAYER
@@ -46,17 +42,19 @@ class ReductionHead(nn.Module):
"""
See :class:`ReIDHeads.forward`.
"""
features = self.pool_layer(features)
global_feat = self.bottleneck(features)
bn_feat = self.bnneck(global_feat)
global_feat = self.pool_layer(features)
bn_feat = self.bottleneck(global_feat)
bn_feat = bn_feat[..., 0, 0]
# Evaluation
if not self.training: return bn_feat
# Training
try: cls_outputs = self.classifier(bn_feat)
except TypeError: cls_outputs = self.classifier(bn_feat, targets)
# Training
if self.classifier.__class__.__name__ == 'Linear':
cls_outputs = self.classifier(bn_feat)
else:
cls_outputs = self.classifier(bn_feat, targets)
pred_class_logits = F.linear(bn_feat, self.classifier.weight)
@@ -66,4 +64,3 @@ class ReductionHead(nn.Module):
raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')")
return cls_outputs, pred_class_logits, feat
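Note on this hunk: the separate `self.bnneck` is folded into the bottleneck `nn.Sequential`, so the forward pass becomes pooling followed by one merged bottleneck. A rough stand-alone sketch of the merged module, using `nn.BatchNorm2d` in place of fastreid's `get_norm` helper:

```python
import torch.nn as nn

def make_reduction_bottleneck(in_feat, reduction_dim):
    bnneck = nn.BatchNorm2d(reduction_dim)
    bnneck.bias.requires_grad_(False)  # rough analogue of bias_freeze=True
    return nn.Sequential(
        nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),  # channel reduction
        nn.BatchNorm2d(reduction_dim),
        nn.LeakyReLU(0.1, inplace=True),
        bnneck,                                               # BNNeck with frozen bias
    )
```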

View File

@@ -29,21 +29,15 @@ class CircleLoss(object):
all_embedding = embedding
all_targets = targets
dist_mat = torch.matmul(embedding, all_embedding.t())
dist_mat = torch.matmul(all_embedding, all_embedding.t())
N, M = dist_mat.size()
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float()
# Compute the mask which ignores the relevance score of the query to itself
if M > N:
identity_indx = torch.eye(N, N, device=is_pos.device)
remain_indx = torch.zeros(N, M - N, device=is_pos.device)
identity_indx = torch.cat((identity_indx, remain_indx), dim=1)
is_pos = is_pos - identity_indx
else:
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg
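Note on this hunk: with the cross-GPU gather branch gone, the similarity matrix is always square, so the rectangular-mask special case (`M > N`) is no longer needed and self-matches are dropped by subtracting an identity matrix. A small sketch of the simplified masks (function name is illustrative):

```python
import torch

def build_pair_masks(targets):
    N = targets.size(0)
    same = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t())
    is_pos = same.float() - torch.eye(N, device=targets.device)  # drop query-to-itself pairs
    is_neg = ~same
    return is_pos, is_neg

# e.g. targets = torch.tensor([0, 0, 1]) -> is_pos keeps only the (0,1) and (1,0) pairs.
```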

View File

@@ -48,7 +48,7 @@ class CrossEntropyLoss(object):
if self._eps >= 0:
smooth_param = self._eps
else:
# adaptive lsr
# Adaptive label smoothing regularization (LSR)
soft_label = F.softmax(pred_class_logits, dim=1)
smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
@@ -60,8 +60,16 @@ class CrossEntropyLoss(object):
loss = (-targets * log_probs).sum(dim=1)
"""
# confidence penalty
conf_penalty = 0.3
probs = F.softmax(pred_class_logits, dim=1)
entropy = torch.sum(-probs * log_probs, dim=1)
loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.)
"""
with torch.no_grad():
non_zero_cnt = max(loss.nonzero().size(0), 1)
non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
loss = loss.sum() / non_zero_cnt
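Note on this hunk: recent PyTorch releases (around 1.5/1.6) emit a deprecation warning when `nonzero()` is called without the `as_tuple` argument, so the non-zero count now uses `as_tuple=False`, which keeps the old return shape. For example:

```python
import torch

loss = torch.tensor([0.0, 0.3, 0.0, 1.2])
# (num_nonzero, ndim) index tensor, the same shape the bare nonzero() used to return
non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)  # -> 2
mean_loss = loss.sum() / non_zero_cnt  # average over the non-zero terms only
```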

View File

@@ -110,11 +110,11 @@ class TripletLoss(object):
all_embedding = embedding
all_targets = targets
dist_mat = euclidean_dist(embedding, all_embedding)
dist_mat = euclidean_dist(all_embedding, all_embedding)
N, M = dist_mat.size()
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
N = dist_mat.size(0)
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
if self._hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
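Note on this hunk: as in the circle loss, dropping the gather branch leaves a square N x N distance matrix with boolean masks. For reference, batch-hard mining over such masks typically picks, per anchor, the farthest positive and the closest negative; a generic sketch (not necessarily fastreid's exact `hard_example_mining`):

```python
import torch

def batch_hard(dist_mat, is_pos, is_neg):
    # For each anchor (row): hardest positive = max distance over positives,
    # hardest negative = min distance over negatives.
    big = torch.finfo(dist_mat.dtype).max
    dist_ap = dist_mat.masked_fill(~is_pos, -big).max(dim=1).values
    dist_an = dist_mat.masked_fill(~is_neg, big).min(dim=1).values
    return dist_ap, dist_an
```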

View File

@@ -7,7 +7,6 @@
import torch
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.

View File

@@ -1,19 +1,13 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import math
import torch
from torch.optim.optimizer import Optimizer
class Adam(Optimizer):
r"""Implements Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
The implementation of the L2 penalty follows changes proposed in
`Decoupled Weight Decay Regularization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
@@ -26,9 +20,10 @@ class Adam(Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
@@ -43,6 +38,8 @@ class Adam(Optimizer):
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Adam, self).__init__(params, defaults)
@@ -52,22 +49,23 @@ class Adam(Optimizer):
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None or group['freeze']:
continue
grad = p.grad.data
grad = p.grad
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
@@ -78,12 +76,12 @@ class Adam(Optimizer):
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
@@ -91,25 +89,25 @@ class Adam(Optimizer):
beta1, beta2 = group['betas']
state['step'] += 1
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)
grad = grad.add(p, alpha=group['weight_decay'])
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
step_size = group['lr'] / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss
return loss
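Note on the optimizer hunks: they follow PyTorch 1.6 conventions: `step()` runs under `@torch.no_grad()` (re-enabling grad only around the closure), gradients are read from `p.grad` instead of `p.grad.data`, state buffers are created with `memory_format=torch.preserve_format`, and the fused in-place ops take their scalar as a keyword (`alpha=` / `value=`) instead of the deprecated first positional argument. A few isolated examples of the new call forms with made-up values:

```python
import torch

grad = torch.ones(4)
exp_avg = torch.zeros_like(grad, memory_format=torch.preserve_format)
exp_avg_sq = torch.zeros_like(grad, memory_format=torch.preserve_format)
beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-8

exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # was add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # was addcmul_(1 - beta2, grad, grad)

param = torch.ones(4)
denom = exp_avg_sq.sqrt().add_(eps)
param.addcdiv_(exp_avg, denom, value=-lr)                     # was addcdiv_(-lr, exp_avg, denom)
```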

View File

@@ -95,21 +95,21 @@ class SGD(Optimizer):
for p in group['params']:
if p.grad is None or group['freeze']:
continue
d_p = p.grad.data
d_p = p.grad
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
d_p.add_(p, alpha=weight_decay)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(momentum, buf)
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf
p.data.add_(-group['lr'], d_p)
p.data.add_(d_p, alpha=-group['lr'])
return loss
return loss