mirror of https://github.com/JDAI-CV/fast-reid.git
Bugfix for cls_layer
In `any_softmax`, all operations are in-place, so pass into the `logits.clone()` to prevent outside logits changed.pull/504/head
parent
c3ac4f504c
commit
6300bd756e
|
@ -2,9 +2,15 @@
|
||||||
|
|
||||||
### v1.3
|
### v1.3
|
||||||
|
|
||||||
#### Improvements
|
#### New Features
|
||||||
|
- Vision Transformer backbone, see config in `configs/Market1501/bagtricks_vit.yml`
|
||||||
|
- Self-Distillation with EMA update
|
||||||
|
- Gradient Clip
|
||||||
|
|
||||||
|
#### Improvements
|
||||||
- Faster dataloader with pre-fetch thread and cuda stream
|
- Faster dataloader with pre-fetch thread and cuda stream
|
||||||
|
- Optimize DDP training speed by removing `find_unused_parameters` in DDP
|
||||||
|
|
||||||
|
|
||||||
### v1.2 (06/04/2021)
|
### v1.2 (06/04/2021)
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ class ClasHead(EmbeddingHead):
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if not self.training: return logits.mul_(self.cls_layer.s)
|
if not self.training: return logits.mul_(self.cls_layer.s)
|
||||||
|
|
||||||
cls_outputs = self.cls_layer(logits, targets)
|
cls_outputs = self.cls_layer(logits.clone(), targets)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cls_outputs": cls_outputs,
|
"cls_outputs": cls_outputs,
|
||||||
|
|
|
@ -4,8 +4,6 @@
|
||||||
@contact: sherlockliao01@gmail.com
|
@contact: sherlockliao01@gmail.com
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -13,7 +11,7 @@ from torch import nn
|
||||||
from fastreid.config import configurable
|
from fastreid.config import configurable
|
||||||
from fastreid.layers import *
|
from fastreid.layers import *
|
||||||
from fastreid.layers import pooling, any_softmax
|
from fastreid.layers import pooling, any_softmax
|
||||||
from fastreid.utils.weight_init import weights_init_kaiming
|
from fastreid.layers.weight_init import weights_init_kaiming
|
||||||
from .build import REID_HEADS_REGISTRY
|
from .build import REID_HEADS_REGISTRY
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,14 +76,19 @@ class EmbeddingHead(nn.Module):
|
||||||
neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
|
neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
|
||||||
|
|
||||||
self.bottleneck = nn.Sequential(*neck)
|
self.bottleneck = nn.Sequential(*neck)
|
||||||
self.bottleneck.apply(weights_init_kaiming)
|
|
||||||
|
|
||||||
# Linear layer
|
# Classification head
|
||||||
assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
|
assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
|
||||||
"but got {}".format(any_softmax.__all__, cls_type)
|
"but got {}".format(any_softmax.__all__, cls_type)
|
||||||
self.weight = nn.Parameter(torch.normal(0, 0.01, (num_classes, feat_dim)))
|
self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
|
||||||
self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
|
self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
self.bottleneck.apply(weights_init_kaiming)
|
||||||
|
nn.init.normal_(self.weight, std=0.01)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, cfg):
|
def from_config(cls, cfg):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
@ -132,7 +135,8 @@ class EmbeddingHead(nn.Module):
|
||||||
else:
|
else:
|
||||||
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
|
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
|
||||||
|
|
||||||
cls_outputs = self.cls_layer(logits, targets)
|
# Pass logits.clone() into cls_layer, because there is in-place operations
|
||||||
|
cls_outputs = self.cls_layer(logits.clone(), targets)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
if self.neck_feat == 'before': feat = pool_feat[..., 0, 0]
|
if self.neck_feat == 'before': feat = pool_feat[..., 0, 0]
|
||||||
|
|
Loading…
Reference in New Issue