# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from ..model_utils import weights_init_kaiming
from ..layers import *


class OSM(nn.Module):
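    """
    OSM head: in training mode it returns the weighted contrastive (OSM) loss together
    with the pooled global feature; in eval mode it returns the batch-normalized embedding.
    (Docstring added by the editor; it only summarizes what the code below does.)
    """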
    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES

        self.pool_layer = nn.Sequential(
            pool_layer,
            Flatten()
        )
        # bnneck
        self.bnneck = NoBiasBatchNorm1d(in_feat)
        self.bnneck.apply(weights_init_kaiming)

        # classifier
        self.alpha = 1.2  # margin of the weighted contrastive loss, as mentioned in the paper
        self.l = 0.5  # hyperparameter controlling the weights of the positive and the negative sets
        # I haven't been able to figure out the use of \sigma CAA 0.18
        self.osm_sigma = 0.8  # \sigma OSM (0.8) as mentioned in the paper

        self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, features, targets=None):
        global_feat = self.pool_layer(features)
        bn_feat = self.bnneck(global_feat)
        if not self.training:
            return bn_feat

        bn_feat = F.normalize(bn_feat)
        n = bn_feat.size(0)

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(bn_feat, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(bn_feat, bn_feat.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability & pairwise distance, d_ij

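        # The pairwise distance above uses ||f_i - f_j||^2 = ||f_i||^2 + ||f_j||^2 - 2 f_i^T f_j,
        # evaluated for the whole batch at once via the in-place addmm_ call.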
        S = torch.exp(-1.0 * torch.pow(dist, 2) / (self.osm_sigma * self.osm_sigma))
        S_ = torch.clamp(self.alpha - dist, min=1e-12)  # max(0, \alpha - d_ij); clamping at 1e-12 instead of 0 avoids nan errors

        p_mask = targets.expand(n, n).eq(targets.expand(n, n).t())  # same label == 1
        n_mask = torch.bitwise_not(p_mask)  # opposite label == 1

        S = S * p_mask.float()
        S = S + S_ * n_mask.float()

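        # At this point S_ij = exp(-d_ij^2 / \sigma_OSM^2) for same-identity pairs and
        # S_ij = max(\alpha - d_ij, 1e-12) for different-identity pairs.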
        denominator = torch.exp(F.linear(bn_feat, F.normalize(self.weight)))

        A = []  # attention corresponding to each feature vector
        for i in range(n):
            # softmax probability of the ground-truth class, used as the attention weight a_i
            a_i = denominator[i][targets[i]] / torch.sum(denominator[i])
            A.append(a_i)
        # a_i's
        atten_class = torch.stack(A)
        # a_ij's
        A = torch.min(atten_class.expand(n, n),
                      atten_class.view(-1, 1).expand(n, n))  # pairwise minimum of attention weights

        W = S * A
        W_P = W * p_mask.float()
        W_N = W * n_mask.float()
        W_P = W_P * (1 - torch.eye(n, n).float().cuda())  # dist between (xi, xi) is not necessarily 0, avoiding precision error
        W_N = W_N * (1 - torch.eye(n, n).float().cuda())

        L_P = 1.0 / 2 * torch.sum(W_P * torch.pow(dist, 2)) / torch.sum(W_P)
        L_N = 1.0 / 2 * torch.sum(W_N * torch.pow(S_, 2)) / torch.sum(W_N)

        L = (1 - self.l) * L_P + self.l * L_N

        return L, global_feat
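
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a cfg object exposing cfg.MODEL.HEADS.NUM_CLASSES and a 4-D
# backbone feature map; the SimpleNamespace wiring, shapes, and class count
# below are hypothetical.
#
#   from types import SimpleNamespace
#   import torch
#
#   cfg = SimpleNamespace(MODEL=SimpleNamespace(HEADS=SimpleNamespace(NUM_CLASSES=751)))
#   head = OSM(cfg, in_feat=2048).cuda()
#
#   features = torch.randn(32, 2048, 16, 8).cuda()  # backbone output (N, C, H, W)
#   targets = torch.randint(0, 751, (32,)).cuda()   # identity labels
#
#   head.train()
#   loss, global_feat = head(features, targets)     # training: OSM loss + pooled feature
#
#   head.eval()
#   bn_feat = head(features)                        # inference: batch-normalized embedding
# ---------------------------------------------------------------------------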