# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.modeling.losses.utils import concat_all_gather
from fastreid.utils import comm
from .baseline import Baseline
from .build import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
    """Baseline meta-architecture extended with a memory-queue metric loss."""

    def __init__(self, cfg):
        """
        Build the baseline model and attach a feature memory queue.

        Args:
            cfg: config node; reads ``MODEL.HEADS.EMBEDDING_DIM``,
                ``MODEL.BACKBONE.FEAT_DIM`` and ``MODEL.QUEUE_SIZE``.
        """
        super().__init__(cfg)

        # Use the backbone feature dim when no (truthy) embedding dim is configured.
        dim = cfg.MODEL.HEADS.EMBEDDING_DIM or cfg.MODEL.BACKBONE.FEAT_DIM
        size = cfg.MODEL.QUEUE_SIZE
        self.memory = Memory(dim, size)

    def losses(self, outputs, gt_labels):
        """
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.

        Adds the memory-queue loss under the key ``'loss_mb'`` to the loss
        dict returned by the parent class.
        """
        # regular reid loss
        loss_dict = super().losses(outputs, gt_labels)

        # memory loss
        pred_features = outputs['features']
        loss_dict['loss_mb'] = self.memory(pred_features, gt_labels)
        return loss_dict


class Memory(nn.Module):
    """
    Build a MoCo memory with a queue
    https://arxiv.org/abs/1911.05722

    The queue stores L2-normalized features column-wise together with their
    labels; every forward pass enqueues the current batch and computes a
    pairwise cosface-style loss between the batch and the whole queue.
    """

    def __init__(self, dim=512, K=65536):
        """
        dim: feature dimension (default: 512)
        K: queue size; number of keys kept in memory (default: 65536)
        """
        super().__init__()
        self.K = K

        # margin and scale used by the pairwise cosface loss
        self.margin = 0.25
        self.gamma = 32

        # create the queue: one unit-norm random column per slot
        self.register_buffer("queue", F.normalize(torch.randn(dim, K), dim=0))
        # Slots that were never written carry label -1 so that they cannot be
        # mistaken for positives of identity 0 (assumes real labels are >= 0).
        self.register_buffer("queue_label", torch.full((1, K), -1, dtype=torch.long))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, targets):
        """
        Overwrite the oldest queue slots with ``keys`` / ``targets``.

        Args:
            keys: features of shape (batch, dim).
            targets: labels with one entry per row of ``keys``.

        Raises:
            ValueError: if the (gathered) batch is larger than the queue.
        """
        # gather keys/targets before updating queue
        if comm.get_world_size() > 1:
            keys = concat_all_gather(keys)
            targets = concat_all_gather(targets)
        else:
            keys = keys.detach()
            targets = targets.detach()

        batch_size = keys.shape[0]
        if batch_size > self.K:
            raise ValueError(
                "batch size ({}) must not exceed the queue size ({})".format(batch_size, self.K)
            )

        ptr = int(self.queue_ptr)

        # replace the keys at ptr (dequeue and enqueue); the write is split in
        # two slice copies so that a batch reaching past the end of the queue
        # wraps around to the front instead of requiring K % batch_size == 0
        head = min(batch_size, self.K - ptr)
        self.queue[:, ptr:ptr + head] = keys[:head].T
        self.queue_label[:, ptr:ptr + head] = targets[:head]

        tail = batch_size - head
        if tail > 0:
            self.queue[:, :tail] = keys[head:].T
            self.queue_label[:, :tail] = targets[head:]

        self.queue_ptr[0] = (ptr + batch_size) % self.K  # move pointer

    def forward(self, feat_q, targets):
        """
        Memory bank enqueue and compute metric loss
        Args:
            feat_q: model features, shape (batch, dim)
            targets: gt labels, one per feature

        Returns:
            scalar loss tensor
        """
        # normalize embedding features
        feat_q = F.normalize(feat_q, p=2, dim=1)

        # Queue column at which this process's first sample is about to be
        # stored; needed to mask each sample's similarity with its own key.
        self_start = int(self.queue_ptr)
        if comm.get_world_size() > 1:
            # NOTE(review): assumes concat_all_gather concatenates equally
            # sized per-rank batches in rank order -- confirm.
            self_start += torch.distributed.get_rank() * feat_q.shape[0]

        # dequeue and enqueue
        self._dequeue_and_enqueue(feat_q.detach(), targets)

        # compute loss
        return self._pairwise_cosface(feat_q, targets, self_start)

    def _pairwise_cosface(self, feat_q, targets, self_start=0):
        """
        Pairwise cosface-style loss between the batch and the memory queue.

        Args:
            feat_q: normalized features, shape (batch, dim).
            targets: labels, one per feature.
            self_start: queue column holding the key of ``feat_q[0]``; row
                ``i`` is assumed to sit at column ``(self_start + i) % K``.

        Returns:
            scalar loss tensor (mean over the batch).
        """
        dist_mat = torch.matmul(feat_q, self.queue)

        N, M = dist_mat.size()  # (bsz, memory)
        labels_q = targets.view(N, 1)
        # (N, 1) against (1, M) broadcasts to the full (N, M) comparison
        is_pos = labels_q.eq(self.queue_label).float()
        is_neg = labels_q.ne(self.queue_label).float()

        # Mask scores related to themselves
        rows = torch.arange(N, device=is_pos.device)
        cols = (rows + self_start) % M
        is_pos[rows, cols] = 0.

        s_p = dist_mat * is_pos
        s_n = dist_mat * is_neg

        # entries outside the respective mask get a huge negative logit so
        # they vanish inside logsumexp
        logit_p = -self.gamma * s_p + (-99999999.) * (1 - is_pos)
        logit_n = self.gamma * (s_n + self.margin) + (-99999999.) * (1 - is_neg)

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

        return loss