import torch
import torch.nn as nn

from openselfsup.utils import print_log

from . import builder
from .registry import MODELS


@MODELS.register_module
class MOCO(nn.Module):
    '''MOCO.

    Part of the code is borrowed from:
    "https://github.com/facebookresearch/moco/blob/master/moco/builder.py".

    Args:
        backbone (dict): Config dict for the backbone.
        neck (dict): Config dict for the neck mapping backbone features to
            compact feature vectors. Default: None.
        head (dict): Config dict for the contrastive loss head.
            Default: None.
        pretrained (str, optional): Path to pre-trained weights.
            Default: None.
        queue_len (int): Number of negative keys maintained in the queue.
            Default: 65536.
        feat_dim (int): Dimension of the compact feature vectors.
            Default: 128.
        momentum (float): Momentum coefficient for the key encoder update.
            Default: 0.999.
    '''

    def __init__(self,
                 backbone,
                 neck=None,
                 head=None,
                 pretrained=None,
                 queue_len=65536,
                 feat_dim=128,
                 momentum=0.999,
                 **kwargs):
        super(MOCO, self).__init__()
        self.encoder_q = nn.Sequential(
            builder.build_backbone(backbone), builder.build_neck(neck))
        self.encoder_k = nn.Sequential(
            builder.build_backbone(backbone), builder.build_neck(neck))
        self.backbone = self.encoder_q[0]
        # the key encoder receives no gradients; it is updated only by
        # momentum in _momentum_update_key_encoder()
        for param in self.encoder_k.parameters():
            param.requires_grad = False
        self.head = builder.build_head(head)
        self.init_weights(pretrained=pretrained)

        self.queue_len = queue_len
        self.momentum = momentum

        # create the queue
        self.register_buffer("queue", torch.randn(feat_dim, queue_len))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
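
        # Note (added): registering the queue as a buffer keeps it in the
        # state_dict and moves it with the module across devices, while
        # excluding it from the parameters seen by the optimizer.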

    def init_weights(self, pretrained=None):
        """Initialize the query encoder and copy its weights to the key
        encoder."""
        if pretrained is not None:
            print_log('load model from: {}'.format(pretrained), logger='root')
        self.encoder_q[0].init_weights(pretrained=pretrained)
        self.encoder_q[1].init_weights(init_linear='kaiming')
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder."""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data = param_k.data * self.momentum + \
                param_q.data * (1. - self.momentum)
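
    # Note (added): with the default momentum of 0.999, the key encoder is a
    # slowly-moving exponential average of the query encoder,
    # theta_k <- m * theta_k + (1 - m) * theta_q, which keeps the keys stored
    # in the queue consistent across training iterations.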

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_len % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.queue_len  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """Batch shuffle, for making use of BatchNorm.

        *** Only supports DistributedDataParallel (DDP). ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle
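
    # Note (added): shuffling ensures each GPU's BatchNorm statistics are
    # computed over a random subset of the global batch, so the model cannot
    # exploit per-GPU batch statistics to match queries with keys (the "BN
    # cheating" issue described in the MoCo paper).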

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """Undo batch shuffle.

        *** Only supports DistributedDataParallel (DDP). ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward_train(self, img, **kwargs):
        """Forward computation during training.

        Args:
            img (Tensor): Input of shape (N, 2, C, H, W); the two crops of
                each image serve as the query and the key.

        Returns:
            dict: A dictionary of loss components.
        """
        assert img.dim() == 5, \
            "Input must have 5 dims, got: {}".format(img.dim())
        im_q = img[:, 0, ...].contiguous()
        im_k = img[:, 1, ...].contiguous()
        # compute query features
        q = self.encoder_q(im_q)[0]  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)[0]  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        losses = self.head(l_pos, l_neg)
        self._dequeue_and_enqueue(k)

        return losses

    def forward_test(self, img, **kwargs):
        pass

    def forward(self, img, mode='train', **kwargs):
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'extract':
            return self.backbone(img)
        else:
            raise Exception("No such mode: {}".format(mode))
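
# ---------------------------------------------------------------------------
# Usage sketch (added): how a model like this is typically instantiated
# through the registry/builder machinery. The import path and the concrete
# backbone/neck/head configs below are illustrative assumptions, not values
# taken from this repository's config files.
#
#   from openselfsup.models import build_model
#
#   moco = build_model(
#       dict(
#           type='MOCO',
#           backbone=dict(type='ResNet', depth=50),
#           neck=dict(type='LinearNeck', in_channels=2048, out_channels=128),
#           head=dict(type='ContrastiveHead', temperature=0.07),
#           queue_len=65536,
#           feat_dim=128,
#           momentum=0.999))
# ---------------------------------------------------------------------------
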
# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """Performs all_gather operation on the provided tensors.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
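
# ---------------------------------------------------------------------------
# Minimal single-process sketch (added): the circular-queue update performed
# by _dequeue_and_enqueue, without the DDP gather. Kept in comments because
# this module's relative imports prevent running it directly; the names
# (demo_queue, demo_ptr) are illustrative only.
#
#   feat_dim, queue_len, batch_size = 4, 16, 8
#   demo_queue = nn.functional.normalize(torch.randn(feat_dim, queue_len),
#                                        dim=0)
#   demo_ptr = 0
#   keys = nn.functional.normalize(torch.randn(batch_size, feat_dim), dim=1)
#   assert queue_len % batch_size == 0  # same simplification as above
#   demo_queue[:, demo_ptr:demo_ptr + batch_size] = keys.transpose(0, 1)
#   demo_ptr = (demo_ptr + batch_size) % queue_len  # wraps around at the end
# ---------------------------------------------------------------------------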