import torch
import torch.nn as nn

from openselfsup.utils import print_log

from . import builder
from .registry import MODELS


@MODELS.register_module
class NPID(nn.Module):
    """NPID.

    Implementation of "Unsupervised Feature Learning via Non-parametric
    Instance Discrimination (https://arxiv.org/abs/1805.01978)".

    Args:
        backbone (dict): Config dict for module of backbone ConvNet.
        neck (dict): Config dict for module of deep features to compact
            feature vectors. Default: None.
        head (dict): Config dict for module of loss functions. Default: None.
        memory_bank (dict): Config dict for module of memory banks.
            Default: None.
        neg_num (int): Number of negative samples for each image.
            Default: 65536.
        ensure_neg (bool): If False, there is a small probability
            that negative samples contain positive ones. Default: False.
        pretrained (str, optional): Path to pre-trained weights.
            Default: None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 memory_bank=None,
                 neg_num=65536,
                 ensure_neg=False,
                 pretrained=None):
        super(NPID, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck)
        self.head = builder.build_head(head)
        self.memory_bank = builder.build_memory(memory_bank)
        self.init_weights(pretrained=pretrained)

        self.neg_num = neg_num
        self.ensure_neg = ensure_neg

    def init_weights(self, pretrained=None):
        """Initialize the weights of model.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Default: None.
        """
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')
        self.backbone.init_weights(pretrained=pretrained)
        self.neck.init_weights(init_linear='kaiming')

    def forward_backbone(self, img):
        """Forward backbone.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            tuple[Tensor]: backbone outputs.
        """
        x = self.backbone(img)
        return x

    def forward_train(self, img, idx, **kwargs):
        """Forward computation during training.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            idx (Tensor): Index corresponding to each image.
            kwargs: Any keyword arguments to be used to forward.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        x = self.forward_backbone(img)
        idx = idx.cuda()
        # Project backbone features to compact, L2-normalized embeddings.
        feature = self.neck(x)[0]
        feature = nn.functional.normalize(feature)  # BxC
        bs, feat_dim = feature.shape[:2]

        # Sample neg_num negative indices per image from the noise
        # distribution maintained by the memory bank.
        neg_idx = self.memory_bank.multinomial.draw(bs * self.neg_num)
        if self.ensure_neg:
            # Re-draw any sampled index that collides with the positive
            # index, so no negative coincides with its own image.
            neg_idx = neg_idx.view(bs, -1)
            while True:
                wrong = (neg_idx == idx.view(-1, 1))
                if wrong.sum().item() > 0:
                    neg_idx[wrong] = self.memory_bank.multinomial.draw(
                        wrong.sum().item())
                else:
                    break
            neg_idx = neg_idx.flatten()

        # Look up positive and negative embeddings from the memory bank.
        pos_feat = torch.index_select(self.memory_bank.feature_bank, 0,
                                      idx)  # BxC
        neg_feat = torch.index_select(self.memory_bank.feature_bank, 0,
                                      neg_idx).view(bs, self.neg_num,
                                                    feat_dim)  # BxKxC

        # Similarity of each embedding to its positive (Bx1) and to its
        # sampled negatives (BxK).
        pos_logits = torch.einsum('nc,nc->n',
                                  [pos_feat, feature]).unsqueeze(-1)
        neg_logits = torch.bmm(neg_feat, feature.unsqueeze(2)).squeeze(2)

        losses = self.head(pos_logits, neg_logits)

        # update memory bank
        with torch.no_grad():
            self.memory_bank.update(idx, feature.detach())

        return losses

    def forward_test(self, img, **kwargs):
        # Test-time forward is not implemented for NPID.
        pass

    def forward(self, img, mode='train', **kwargs):
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            return self.forward_backbone(img)
        else:
            raise Exception("No such mode: {}".format(mode))
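

# A minimal, self-contained sketch of the NCE-style logit computation used in
# `forward_train` above (pure torch; no registry, builder, or memory bank
# involved). The sizes below are illustrative assumptions, not values taken
# from this file.
if __name__ == '__main__':
    bs, feat_dim, neg_num = 4, 128, 16
    feature = nn.functional.normalize(torch.randn(bs, feat_dim))    # BxC
    pos_feat = nn.functional.normalize(torch.randn(bs, feat_dim))   # BxC
    neg_feat = nn.functional.normalize(
        torch.randn(bs, neg_num, feat_dim), dim=2)                  # BxKxC
    # Positive logits: dot product of each embedding with its own memory
    # entry -> Bx1. Negative logits: dot products with K sampled entries -> BxK.
    pos_logits = torch.einsum('nc,nc->n', [pos_feat, feature]).unsqueeze(-1)
    neg_logits = torch.bmm(neg_feat, feature.unsqueeze(2)).squeeze(2)
    assert pos_logits.shape == (bs, 1)
    assert neg_logits.shape == (bs, neg_num)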