diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index b3c3522..d6ee545 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -496,11 +496,13 @@ class DefaultTrainer(SimpleTrainer):
             cfg.SOLVER.STEPS[i] *= iters_per_epoch
         cfg.SOLVER.SWA.ITER *= iters_per_epoch
         cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
-        cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
+        ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
         # Evaluation period must be divided by 200 for writing into tensorboard.
-        num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
-        cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + num_mod
+        eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
+        cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
+        # Change checkpoint saving period consistent with evaluation period.
+        cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)

         logger = logging.getLogger(__name__)
         logger.info(
diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py
index 71e55f5..04097b8 100644
--- a/fastreid/engine/train_loop.py
+++ b/fastreid/engine/train_loop.py
@@ -210,11 +210,6 @@ class SimpleTrainer(TrainerBase):
         loss_dict = self.model.losses(outputs, targets)
         losses = sum(loss_dict.values())

-        self._detect_anomaly(losses, loss_dict)
-
-        metrics_dict = loss_dict
-        metrics_dict["data_time"] = data_time
-        self._write_metrics(metrics_dict)

         """
         If you need accumulate gradients or something similar, you can
@@ -223,6 +218,12 @@ class SimpleTrainer(TrainerBase):
         self.optimizer.zero_grad()
         losses.backward()

+        with torch.cuda.stream(torch.cuda.Stream()):
+            metrics_dict = loss_dict
+            metrics_dict["data_time"] = data_time
+            self._write_metrics(metrics_dict)
+            self._detect_anomaly(losses, loss_dict)
+
         """
         If you need gradient clipping/scaling or other processing, you can
         wrap the optimizer with your custom `step()` method.
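The defaults.py hunk above keeps the checkpoint period locked to a fixed multiple of the evaluation period after both are rescaled from epochs to iterations. A minimal sketch of that arithmetic, using plain integers instead of the yacs config node (the function name and the example numbers are illustrative, not part of the patch):

def scale_periods(eval_period_epochs, ckpt_period_epochs, iters_per_epoch):
    # Keep the checkpoint period a fixed multiple of the evaluation period.
    ckpt_multiple = ckpt_period_epochs / eval_period_epochs
    # Round the evaluation period up to a multiple of 200 iterations so it
    # lines up with the tensorboard writing period used by the trainer.
    eval_num_mod = (200 - eval_period_epochs * iters_per_epoch) % 200
    eval_period_iters = eval_period_epochs * iters_per_epoch + eval_num_mod
    ckpt_period_iters = int(eval_period_iters * ckpt_multiple)
    return eval_period_iters, ckpt_period_iters

# e.g. evaluate every 30 epochs, checkpoint every 60, 211 iterations per epoch:
# scale_periods(30, 60, 211) -> (6400, 12800)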
diff --git a/fastreid/layers/batch_norm.py b/fastreid/layers/batch_norm.py
index 47a861e..dddecf6 100644
--- a/fastreid/layers/batch_norm.py
+++ b/fastreid/layers/batch_norm.py
@@ -24,8 +24,8 @@ class BatchNorm(nn.BatchNorm2d):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
-        if weight_init is not None: self.weight.data.fill_(weight_init)
-        if bias_init is not None: self.bias.data.fill_(bias_init)
+        if weight_init is not None: nn.init.constant_(self.weight, weight_init)
+        if bias_init is not None: nn.init.constant_(self.bias, bias_init)
         self.weight.requires_grad_(not weight_freeze)
         self.bias.requires_grad_(not bias_freeze)
diff --git a/fastreid/layers/pooling.py b/fastreid/layers/pooling.py
index 5aec39e..4741a54 100644
--- a/fastreid/layers/pooling.py
+++ b/fastreid/layers/pooling.py
@@ -57,13 +57,14 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
 class AdaptiveAvgMaxPool2d(nn.Module):
     def __init__(self):
         super(AdaptiveAvgMaxPool2d, self).__init__()
-        self.avgpool = FastGlobalAvgPool2d()
+        self.gap = FastGlobalAvgPool2d()
+        self.gmp = nn.AdaptiveMaxPool2d(1)

     def forward(self, x):
-        x_avg = self.avgpool(x, self.output_size)
-        x_max = F.adaptive_max_pool2d(x, 1)
-        x = x_max + x_avg
-        return x
+        avg_feat = self.gap(x)
+        max_feat = self.gmp(x)
+        feat = avg_feat + max_feat
+        return feat


 class FastGlobalAvgPool2d(nn.Module):
diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py
index 18b5c57..6080f72 100644
--- a/fastreid/modeling/backbones/resnet.py
+++ b/fastreid/modeling/backbones/resnet.py
@@ -304,10 +304,27 @@ def build_resnet_backbone(cfg):
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH

-    num_blocks_per_stage = {'18x': [2, 2, 2, 2], '34x': [3, 4, 6, 3], '50x': [3, 4, 6, 3],
-                            '101x': [3, 4, 23, 3],}[depth]
-    nl_layers_per_stage = {'18x': [0, 0, 0, 0], '34x': [0, 0, 0, 0], '50x': [0, 2, 3, 0], '101x': [0, 2, 9, 0]}[depth]
-    block = {'18x': BasicBlock, '34x': BasicBlock, '50x': Bottleneck, '101x': Bottleneck}[depth]
+    num_blocks_per_stage = {
+        '18x': [2, 2, 2, 2],
+        '34x': [3, 4, 6, 3],
+        '50x': [3, 4, 6, 3],
+        '101x': [3, 4, 23, 3],
+    }[depth]
+
+    nl_layers_per_stage = {
+        '18x': [0, 0, 0, 0],
+        '34x': [0, 0, 0, 0],
+        '50x': [0, 2, 3, 0],
+        '101x': [0, 2, 9, 0]
+    }[depth]
+
+    block = {
+        '18x': BasicBlock,
+        '34x': BasicBlock,
+        '50x': Bottleneck,
+        '101x': Bottleneck
+    }[depth]
+
     model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
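The pooling.py hunk above rewrites AdaptiveAvgMaxPool2d as the sum of a global average pool and a global max pool held as submodules. A self-contained sketch of the same idea, substituting nn.AdaptiveAvgPool2d for fastreid's FastGlobalAvgPool2d (an assumption made only so the snippet runs on its own):

import torch
import torch.nn as nn

class AvgMaxPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)  # stand-in for FastGlobalAvgPool2d
        self.gmp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        # Element-wise sum of globally average-pooled and max-pooled features.
        return self.gap(x) + self.gmp(x)

# AvgMaxPool2d()(torch.randn(8, 2048, 16, 8)).shape -> torch.Size([8, 2048, 1, 1])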
diff --git a/fastreid/modeling/heads/bnneck_head.py b/fastreid/modeling/heads/bnneck_head.py
index 5b46c6a..0cd21af 100644
--- a/fastreid/modeling/heads/bnneck_head.py
+++ b/fastreid/modeling/heads/bnneck_head.py
@@ -36,6 +36,7 @@ class BNneckHead(nn.Module):
         See :class:`ReIDHeads.forward`.
         """
         global_feat = self.pool_layer(features)
+        global_feat = torch.clamp(global_feat, min=0., max=1.)
         bn_feat = self.bnneck(global_feat)
         bn_feat = bn_feat[..., 0, 0]
@@ -43,8 +44,10 @@ class BNneckHead(nn.Module):
         if not self.training: return bn_feat
         # Training
-        try: cls_outputs = self.classifier(bn_feat)
-        except TypeError: cls_outputs = self.classifier(bn_feat, targets)
+        if self.classifier.__class__.__name__ == 'Linear':
+            cls_outputs = self.classifier(bn_feat)
+        else:
+            cls_outputs = self.classifier(bn_feat, targets)

         pred_class_logits = F.linear(bn_feat, self.classifier.weight)
diff --git a/fastreid/modeling/heads/linear_head.py b/fastreid/modeling/heads/linear_head.py
index b22bd7f..d3816d2 100644
--- a/fastreid/modeling/heads/linear_head.py
+++ b/fastreid/modeling/heads/linear_head.py
@@ -38,8 +38,11 @@ class LinearHead(nn.Module):
         if not self.training: return global_feat
         # Training
-        try: cls_outputs = self.classifier(global_feat)
-        except TypeError: cls_outputs = self.classifier(global_feat, targets)
+        if self.classifier.__class__.__name__ == 'Linear':
+            cls_outputs = self.classifier(global_feat)
+        else:
+            cls_outputs = self.classifier(global_feat, targets)
+
         pred_class_logits = F.linear(global_feat, self.classifier.weight)
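The two head hunks above replace the try/except dispatch with an explicit check on the classifier's class name: a plain nn.Linear is called with features only, while margin-style classifiers are expected to also take the targets. A sketch of that dispatch; DummyMarginClassifier is a hypothetical stand-in, not a fastreid layer:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DummyMarginClassifier(nn.Module):
    # Hypothetical stand-in for a margin-based classifier that needs targets.
    def __init__(self, in_feat, num_classes):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_feat))

    def forward(self, features, targets):
        # A real margin classifier would use `targets` to apply a margin here.
        return F.linear(features, self.weight)

def run_classifier(classifier, bn_feat, targets):
    # Mirrors the dispatch in the heads: nn.Linear gets features only.
    if classifier.__class__.__name__ == 'Linear':
        return classifier(bn_feat)
    else:
        return classifier(bn_feat, targets)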
""" - features = self.pool_layer(features) - global_feat = self.bottleneck(features) - bn_feat = self.bnneck(global_feat) + global_feat = self.pool_layer(features) + bn_feat = self.bottleneck(global_feat) bn_feat = bn_feat[..., 0, 0] # Evaluation if not self.training: return bn_feat # Training - try: cls_outputs = self.classifier(bn_feat) - except TypeError: cls_outputs = self.classifier(bn_feat, targets) + # Training + if self.classifier.__class__.__name__ == 'Linear': + cls_outputs = self.classifier(bn_feat) + else: + cls_outputs = self.classifier(bn_feat, targets) pred_class_logits = F.linear(bn_feat, self.classifier.weight) @@ -66,4 +64,3 @@ class ReductionHead(nn.Module): raise KeyError("MODEL.HEADS.NECK_FEAT value is invalid, must choose from ('after' & 'before')") return cls_outputs, pred_class_logits, feat - diff --git a/fastreid/modeling/losses/circle_loss.py b/fastreid/modeling/losses/circle_loss.py index b945659..2d3fcf4 100644 --- a/fastreid/modeling/losses/circle_loss.py +++ b/fastreid/modeling/losses/circle_loss.py @@ -29,21 +29,15 @@ class CircleLoss(object): all_embedding = embedding all_targets = targets - dist_mat = torch.matmul(embedding, all_embedding.t()) + dist_mat = torch.matmul(all_embedding, all_embedding.t()) - N, M = dist_mat.size() - is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float() + N = dist_mat.size(0) + is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float() # Compute the mask which ignores the relevance score of the query to itself - if M > N: - identity_indx = torch.eye(N, N, device=is_pos.device) - remain_indx = torch.zeros(N, M - N, device=is_pos.device) - identity_indx = torch.cat((identity_indx, remain_indx), dim=1) - is_pos = is_pos - identity_indx - else: - is_pos = is_pos - torch.eye(N, N, device=is_pos.device) + is_pos = is_pos - torch.eye(N, N, device=is_pos.device) - is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t()) + is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) s_p = dist_mat * is_pos s_n = dist_mat * is_neg diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py index c86e47e..e6d9add 100644 --- a/fastreid/modeling/losses/cross_entroy_loss.py +++ b/fastreid/modeling/losses/cross_entroy_loss.py @@ -48,7 +48,7 @@ class CrossEntropyLoss(object): if self._eps >= 0: smooth_param = self._eps else: - # adaptive lsr + # Adaptive label smooth regularization soft_label = F.softmax(pred_class_logits, dim=1) smooth_param = self._alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1) @@ -60,8 +60,16 @@ class CrossEntropyLoss(object): loss = (-targets * log_probs).sum(dim=1) + """ + # confidence penalty + conf_penalty = 0.3 + probs = F.softmax(pred_class_logits, dim=1) + entropy = torch.sum(-probs * log_probs, dim=1) + loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.) 
+ """ + with torch.no_grad(): - non_zero_cnt = max(loss.nonzero().size(0), 1) + non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1) loss = loss.sum() / non_zero_cnt diff --git a/fastreid/modeling/losses/triplet_loss.py b/fastreid/modeling/losses/triplet_loss.py index 5e1aa74..809f28f 100644 --- a/fastreid/modeling/losses/triplet_loss.py +++ b/fastreid/modeling/losses/triplet_loss.py @@ -110,11 +110,11 @@ class TripletLoss(object): all_embedding = embedding all_targets = targets - dist_mat = euclidean_dist(embedding, all_embedding) + dist_mat = euclidean_dist(all_embedding, all_embedding) - N, M = dist_mat.size() - is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()) - is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t()) + N = dist_mat.size(0) + is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()) + is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) if self._hard_mining: dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg) diff --git a/fastreid/modeling/losses/utils.py b/fastreid/modeling/losses/utils.py index ce64dd8..3eb6630 100644 --- a/fastreid/modeling/losses/utils.py +++ b/fastreid/modeling/losses/utils.py @@ -7,7 +7,6 @@ import torch -@torch.no_grad() def concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. diff --git a/fastreid/solver/optim/adam.py b/fastreid/solver/optim/adam.py index a9b83e3..e223606 100644 --- a/fastreid/solver/optim/adam.py +++ b/fastreid/solver/optim/adam.py @@ -1,19 +1,13 @@ -# encoding: utf-8 -""" -@author: xingyu liao -@contact: sherlockliao01@gmail.com -""" - -import torch import math +import torch from torch.optim.optimizer import Optimizer class Adam(Optimizer): r"""Implements Adam algorithm. - It has been proposed in `Adam: A Method for Stochastic Optimization`_. - + The implementation of the L2 penalty follows changes proposed in + `Decoupled Weight Decay Regularization`_. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -26,9 +20,10 @@ class Adam(Optimizer): amsgrad (boolean, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) - .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ @@ -43,6 +38,8 @@ class Adam(Optimizer): raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super(Adam, self).__init__(params, defaults) @@ -52,22 +49,23 @@ class Adam(Optimizer): for group in self.param_groups: group.setdefault('amsgrad', False) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. - Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None or group['freeze']: continue - grad = p.grad.data + grad = p.grad if grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') amsgrad = group['amsgrad'] @@ -78,12 +76,12 @@ class Adam(Optimizer): if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p.data) + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] if amsgrad: @@ -91,25 +89,25 @@ class Adam(Optimizer): beta1, beta2 = group['betas'] state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] if group['weight_decay'] != 0: - grad.add_(group['weight_decay'], p.data) + grad = grad.add(p, alpha=group['weight_decay']) # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(1 - beta1, grad) - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) # Use the max. for normalizing running avg. of gradient - denom = max_exp_avg_sq.sqrt().add_(group['eps']) + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) else: - denom = exp_avg_sq.sqrt().add_(group['eps']) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] - step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + step_size = group['lr'] / bias_correction1 - p.data.addcdiv_(-step_size, exp_avg, denom) + p.addcdiv_(exp_avg, denom, value=-step_size) - return loss + return loss \ No newline at end of file diff --git a/fastreid/solver/optim/sgd.py b/fastreid/solver/optim/sgd.py index 9e31bc0..c7c37dc 100644 --- a/fastreid/solver/optim/sgd.py +++ b/fastreid/solver/optim/sgd.py @@ -95,21 +95,21 @@ class SGD(Optimizer): for p in group['params']: if p.grad is None or group['freeze']: continue - d_p = p.grad.data + d_p = p.grad if weight_decay != 0: - d_p.add_(weight_decay, p.data) + d_p.add_(p, alpha=weight_decay) if momentum != 0: param_state = self.state[p] if 'momentum_buffer' not in param_state: buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() else: buf = param_state['momentum_buffer'] - buf.mul_(momentum).add_(1 - dampening, d_p) + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) if nesterov: - d_p = d_p.add(momentum, buf) + d_p = d_p.add(buf, alpha=momentum) else: d_p = buf - p.data.add_(-group['lr'], d_p) + p.data.add_(d_p, alpha=-group['lr']) - return loss \ No newline at end of file + return loss