# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from ..model_utils import weights_init_kaiming
from ..layers import *


class OSM(nn.Module):
    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

        self.pool_layer = nn.Sequential(
            pool_layer,
            Flatten()
        )
        # bnneck
        self.bnneck = NoBiasBatchNorm1d(in_feat)
        self.bnneck.apply(weights_init_kaiming)

        # classifier
        self.alpha = 1.2  # margin of the weighted contrastive loss, as mentioned in the paper
        self.l = 0.5  # hyperparameter balancing the weights of the positive and negative sets
        # the role of \sigma_CAA (0.18) in the paper is unclear, so it is not used here
        self.osm_sigma = 0.8  # \sigma_OSM (0.8), as mentioned in the paper

        self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets=None):
        global_feat = self.pool_layer(features)
        bn_feat = self.bnneck(global_feat)
        if not self.training:
            return bn_feat

        bn_feat = F.normalize(bn_feat)
        n = bn_feat.size(0)

        # Compute the pairwise distance via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b;
        # replace by the official op when merged
        dist = torch.pow(bn_feat, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(bn_feat, bn_feat.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # clamp for numerical stability; dist[i][j] = d_ij
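        # Note (an assumption, not from the original): torch.cdist(bn_feat, bn_feat)
        # should produce the same pairwise Euclidean distances in a single call;
        # the manual expansion above is kept until the official op is adopted,
        # as the comment above suggests.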

        S = torch.exp(-1.0 * torch.pow(dist, 2) / (self.osm_sigma * self.osm_sigma))
        S_ = torch.clamp(self.alpha - dist, min=1e-12)  # max(0, \alpha - d_ij); clamp to 1e-12 because 0 may cause a nan error
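        # S is a Gaussian similarity applied to positive pairs below, while S_ is a
        # hinge-margin term (contrastive-style) applied to negative pairs.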

        p_mask = targets.expand(n, n).eq(targets.expand(n, n).t())  # same label == 1
        n_mask = ~p_mask  # opposite label == 1

        S = S * p_mask.float()
        S = S + S_ * n_mask.float()

        denominator = torch.exp(F.linear(bn_feat, F.normalize(self.weight)))
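        # Each row of `denominator` holds exp(cosine logits) of a normalized feature
        # against the normalized class weights, i.e. the unnormalized terms of a
        # softmax over classes; the loop below takes the target class's share.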

        A = []  # attention corresponding to each feature vector
        for i in range(n):
            a_i = denominator[i][targets[i]] / torch.sum(denominator[i])
            A.append(a_i)
        # a_i's
        atten_class = torch.stack(A)
        # a_ij's
        A = torch.min(atten_class.expand(n, n),
                      atten_class.view(-1, 1).expand(n, n))  # pairwise minimum of attention weights
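        # A vectorized sketch equivalent to the loop above (assuming `targets` is a
        # LongTensor of shape (n,)):
        #   atten_class = denominator[torch.arange(n), targets] / denominator.sum(dim=1)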

        W = S * A
        W_P = W * p_mask.float()
        W_N = W * n_mask.float()
        # dist between (x_i, x_i) is not necessarily 0, so zero the diagonal to avoid precision errors
        eye = torch.eye(n, n, device=dist.device)
        W_P = W_P * (1 - eye)
        W_N = W_N * (1 - eye)

        L_P = 1.0 / 2 * torch.sum(W_P * torch.pow(dist, 2)) / torch.sum(W_P)
        L_N = 1.0 / 2 * torch.sum(W_N * torch.pow(S_, 2)) / torch.sum(W_N)

        L = (1 - self.l) * L_P + self.l * L_N

        return L, global_feat
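
# A minimal usage sketch (assumptions, not from the original: `cfg` follows the
# fastreid config layout with MODEL.HEADS.NUM_CLASSES set, and `backbone` returns
# an (N, in_feat, H, W) feature map; the names are illustrative):
#
#   head = OSM(cfg, in_feat=2048)
#   loss, global_feat = head(backbone(images), targets)  # training mode
#   head.eval()
#   embeddings = head(backbone(images))                  # inference: bnneck features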