mirror of https://github.com/JDAI-CV/fast-reid.git
127 lines
3.9 KiB
Python
127 lines
3.9 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: xingyu liao
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from fastreid.modeling.losses.utils import concat_all_gather
|
|
from fastreid.utils import comm
|
|
from .baseline import Baseline
|
|
from .build import META_ARCH_REGISTRY
|
|
|
|
|
|
@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
    """Baseline re-id model extended with a memory-bank (queue) loss.

    The usual ``Baseline`` losses are kept, and an extra term computed by a
    ``Memory`` queue over the model's output features is added under the key
    ``'loss_mb'``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Prefer the head's embedding size when it is set (truthy); otherwise
        # the features come straight from the backbone.
        feat_dim = cfg.MODEL.HEADS.EMBEDDING_DIM or cfg.MODEL.BACKBONE.FEAT_DIM
        self.memory = Memory(feat_dim, cfg.MODEL.QUEUE_SIZE)

    def losses(self, outputs, gt_labels):
        """
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.

        Args:
            outputs: dict produced by the model forward; must contain 'features'.
            gt_labels: ground-truth labels for the batch.

        Returns:
            The ``Baseline`` loss dict with an additional ``'loss_mb'`` entry.
        """
        # Regular re-id losses from the parent model first, then the memory term.
        loss_dict = super().losses(outputs, gt_labels)
        loss_dict['loss_mb'] = self.memory(outputs['features'], gt_labels)
        return loss_dict
|
|
|
|
|
|
class Memory(nn.Module):
    """
    Build a MoCo memory with a queue
    https://arxiv.org/abs/1911.05722

    The queue stores L2-normalized features (one per column) and their labels.
    Each forward pass enqueues the current batch, overwriting the oldest
    entries, and then computes a pairwise cosface loss between the batch and
    the whole queue.
    """

    def __init__(self, dim=512, K=65536, margin=0.25, gamma=32):
        """
        dim: feature dimension (default: 512)
        K: queue size; number of stored keys (default: 65536)
        margin: additive margin applied to negative similarities (default: 0.25)
        gamma: scale applied to the similarity logits (default: 32)
        """
        super().__init__()
        self.K = K

        self.margin = margin
        self.gamma = gamma

        # Create the queue: random unit-norm columns until real features arrive.
        self.register_buffer("queue", F.normalize(torch.randn(dim, K), dim=0))
        # -1 marks slots that do not hold a real sample yet, so they can never
        # match a ground-truth label. A 0 fill would turn every empty slot into
        # a false positive for class 0. NOTE(review): assumes labels are
        # non-negative class indices — confirm for custom datasets.
        self.register_buffer("queue_label", torch.full((1, K), -1, dtype=torch.long))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, targets):
        """Overwrite the oldest queue slots with ``keys`` and ``targets``.

        With more than one process, keys and targets are first gathered from
        all ranks, so ``world_size * local_batch`` slots are written starting
        at ``queue_ptr``. The pointer then advances, wrapping at ``K``.

        Raises:
            ValueError: if ``K`` is not a multiple of the (gathered) batch size.
        """
        # Gather keys/targets before updating the queue.
        if comm.get_world_size() > 1:
            keys = concat_all_gather(keys)
            targets = concat_all_gather(targets)
        else:
            keys = keys.detach()
            targets = targets.detach()

        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        # Because batch_size divides K, a write never runs past the end of the queue.
        if self.K % batch_size != 0:
            raise ValueError(
                f"queue size K={self.K} must be divisible by the (gathered) "
                f"batch size {batch_size}"
            )

        # Replace the keys at ptr (dequeue and enqueue).
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_label[:, ptr:ptr + batch_size] = targets
        self.queue_ptr[0] = (ptr + batch_size) % self.K  # move pointer

    def forward(self, feat_q, targets):
        """
        Memory bank enqueue and compute metric loss
        Args:
            feat_q: model features, one row per sample
            targets: gt labels

        Returns:
            Scalar pairwise cosface loss against the updated queue.
        """
        # Normalize embedding features.
        feat_q = F.normalize(feat_q, p=2, dim=1)
        # Record where this process's samples will land in the queue, so their
        # own (detached) copies can be excluded from the positives below.
        # This assumes the distributed gather concatenates in rank order with
        # equal per-rank batch sizes — TODO confirm against concat_all_gather.
        self_start = int(self.queue_ptr) + comm.get_rank() * feat_q.size(0)
        # Dequeue and enqueue.
        self._dequeue_and_enqueue(feat_q.detach(), targets)
        # Compute loss.
        return self._pairwise_cosface(feat_q, targets, self_start)

    def _pairwise_cosface(self, feat_q, targets, self_start=0):
        """Pairwise cosface loss between a batch and the memory queue.

        Args:
            feat_q: (N, dim) query features; ``forward`` L2-normalizes them, so
                the matmul with the queue gives cosine similarities.
            targets: labels of ``feat_q``; reshaped to (N, 1).
            self_start: queue column that holds the copy of ``feat_q[0]``.
                Row i's own copy at column ``self_start + i`` is removed from
                its positives.

        Returns:
            Scalar loss averaged over the N rows.
        """
        dist_mat = torch.matmul(feat_q, self.queue)

        N, M = dist_mat.size()  # (bsz, memory)
        labels = targets.view(N, 1)
        # (N, 1) vs (1, M) broadcasts to an (N, M) label-agreement mask.
        is_pos = labels.eq(self.queue_label).float()
        is_neg = labels.ne(self.queue_label).float()

        # Mask scores related to themselves. The batch was just enqueued, so
        # each query has an exact copy in the queue that must not count as a
        # positive.
        rows = torch.arange(N, device=is_pos.device)
        is_pos[rows, rows + self_start] = 0

        s_p = dist_mat * is_pos
        s_n = dist_mat * is_neg

        # The large negative offset removes non-positive / non-negative
        # entries from the respective logsumexp.
        logit_p = -self.gamma * s_p + (-99999999.) * (1 - is_pos)
        logit_n = self.gamma * (s_n + self.margin) + (-99999999.) * (1 - is_neg)

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

        return loss
|