Bugfix for cls_layer

In `any_softmax`, all operations are in-place, so pass `logits.clone()` into it to prevent the caller's logits from being modified.
pull/504/head
liaoxingyu 2021-05-31 17:32:24 +08:00
parent c3ac4f504c
commit 6300bd756e
3 changed files with 19 additions and 9 deletions
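
To see why the clone matters, here is a minimal standalone sketch (not the actual `any_softmax` code; the `InPlaceScale` class and its `s` value are illustrative) of a cls_layer whose forward mutates its input in place, and how passing `logits.clone()` keeps the caller's tensor untouched:

```python
import torch

# Illustrative stand-in for a cls_layer that, like the ops in `any_softmax`,
# modifies its input tensor in place.
class InPlaceScale(torch.nn.Module):
    def __init__(self, s: float = 16.0):
        super().__init__()
        self.s = s

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.mul_(self.s)  # in-place: the caller's tensor is changed too

layer = InPlaceScale()
logits = torch.randn(4, 10)

out = layer(logits)          # `logits` itself is now scaled by s
out = layer(logits.clone())  # only the clone is scaled; `logits` stays intact
```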


@@ -2,9 +2,15 @@
### v1.3
#### Improvements
#### New Features
- Vision Transformer backbone, see config in `configs/Market1501/bagtricks_vit.yml`
- Self-Distillation with EMA update
- Gradient Clip
#### Improvements
- Faster dataloader with pre-fetch thread and cuda stream
- Optimize DDP training speed by removing `find_unused_parameters` in DDP
### v1.2 (06/04/2021)


@@ -27,7 +27,7 @@ class ClasHead(EmbeddingHead):
# Evaluation
if not self.training: return logits.mul_(self.cls_layer.s)
cls_outputs = self.cls_layer(logits, targets)
cls_outputs = self.cls_layer(logits.clone(), targets)
return {
"cls_outputs": cls_outputs,


@@ -4,8 +4,6 @@
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn.functional as F
from torch import nn
@@ -13,7 +11,7 @@ from torch import nn
from fastreid.config import configurable
from fastreid.layers import *
from fastreid.layers import pooling, any_softmax
from fastreid.utils.weight_init import weights_init_kaiming
from fastreid.layers.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY
@@ -78,14 +76,19 @@ class EmbeddingHead(nn.Module):
neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
self.bottleneck = nn.Sequential(*neck)
self.bottleneck.apply(weights_init_kaiming)
# Linear layer
# Classification head
assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
"but got {}".format(any_softmax.__all__, cls_type)
self.weight = nn.Parameter(torch.normal(0, 0.01, (num_classes, feat_dim)))
self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
self.reset_parameters()
def reset_parameters(self) -> None:
self.bottleneck.apply(weights_init_kaiming)
nn.init.normal_(self.weight, std=0.01)
@classmethod
def from_config(cls, cfg):
# fmt: off
@@ -132,7 +135,8 @@
else:
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
cls_outputs = self.cls_layer(logits, targets)
# Pass logits.clone() into cls_layer because it performs in-place operations
cls_outputs = self.cls_layer(logits.clone(), targets)
# fmt: off
if self.neck_feat == 'before': feat = pool_feat[..., 0, 0]
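
The weight-initialization change above follows the usual PyTorch pattern of allocating the parameter uninitialized and filling it inside `reset_parameters()`, so re-initialization is a single explicit call. A minimal sketch of that pattern (the `ToyHead` class and its sizes are illustrative, not fastreid code):

```python
import torch
from torch import nn

class ToyHead(nn.Module):
    def __init__(self, num_classes: int, feat_dim: int):
        super().__init__()
        # Allocate the classification weight without initializing its values.
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # All parameter initialization lives here, mirroring the diff above.
        nn.init.normal_(self.weight, std=0.01)

head = ToyHead(num_classes=751, feat_dim=2048)
```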