moco-v3/moco/builder.py

112 lines
4.3 KiB
Python
Raw Normal View History

2021-06-17 17:09:43 +08:00
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
2021-06-17 10:59:59 +08:00
import torch
import torch.nn as nn
class MoCo(nn.Module):
    """
    Build a MoCo model with: a base encoder, a momentum encoder
    https://arxiv.org/abs/1911.05722

    The base encoder is trained by gradient and is followed by a 2-layer
    predictor. The momentum encoder has the same architecture, starts as a
    copy of the base encoder's parameters and is only ever changed by the
    moving-average update in `_update_momentum_encoder`.
    """

    def __init__(self, base_encoder, dim=256, mlp_dim=4096, m=0.99, T=1.0):
        """
        base_encoder: callable that builds the backbone; it is invoked as
            base_encoder(num_classes=mlp_dim, zero_init_residual=True) and the
            module it returns must expose its last linear layer as `.fc`
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 4096)
        m: moco momentum of updating momentum encoder (default: 0.99)
        T: softmax temperature (default: 1.0)
        """
        super().__init__()

        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the hidden MLP dimension
        self.base_encoder = base_encoder(num_classes=mlp_dim, zero_init_residual=True)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim, zero_init_residual=True)

        # The backbone's own fc becomes the first layer of a 2-layer MLP head.
        # Heads are built base-first so random initialisation order is stable.
        self.base_encoder.fc = self._build_mlp_head(self.base_encoder.fc, mlp_dim, dim)
        first_bias = self.base_encoder.fc[0].bias
        if first_bias is not None:
            # hack: not use bias as it is followed by BN
            first_bias.requires_grad = False
        self.momentum_encoder.fc = self._build_mlp_head(self.momentum_encoder.fc, mlp_dim, dim)

        for param_b, param_m in zip(self.base_encoder.parameters(),
                                    self.momentum_encoder.parameters()):
            param_m.data.copy_(param_b.data)  # initialize
            param_m.requires_grad = False  # not update by gradient

        # build a 2-layer predictor
        self.predictor = nn.Sequential(nn.Linear(dim, mlp_dim, bias=False),
                                       nn.BatchNorm1d(mlp_dim),
                                       nn.ReLU(inplace=True),  # hidden layer
                                       nn.Linear(mlp_dim, dim))  # output layer

    @staticmethod
    def _build_mlp_head(first_layer, mlp_dim, dim):
        """Wrap `first_layer` into the head: first_layer -> BN -> ReLU -> Linear(mlp_dim, dim)."""
        return nn.Sequential(first_layer,
                             nn.BatchNorm1d(mlp_dim),
                             nn.ReLU(inplace=True),  # first layer
                             nn.Linear(mlp_dim, dim))  # second layer

    @torch.no_grad()
    def _update_momentum_encoder(self):
        """Momentum update of the momentum encoder"""
        for param_b, param_m in zip(self.base_encoder.parameters(),
                                    self.momentum_encoder.parameters()):
            # param_m <- m * param_m + (1 - m) * param_b, done in place so no
            # new tensor is allocated per parameter on every step.
            param_m.mul_(self.m).add_(param_b, alpha=1. - self.m)

    def forward(self, im1, im2):
        """
        Input:
            im1: first views of images
            im2: second views of images
        Output:
            logits1: similarities of predictions for im1 against the gathered
                momentum features of im2, divided by T
            logits2: similarities of predictions for im2 against the gathered
                momentum features of im1, divided by T
            labels: for each local sample, the column of its own target in the
                gathered features (local index + N * rank), on the same device
                as the logits
        """
        # compute features
        p1 = self.predictor(self.base_encoder(im1))
        p2 = self.predictor(self.base_encoder(im2))
        # normalize
        p1 = nn.functional.normalize(p1, dim=1)
        p2 = nn.functional.normalize(p2, dim=1)

        # compute momentum features as targets
        with torch.no_grad():  # no gradient
            self._update_momentum_encoder()  # update the momentum encoder
            t1 = self.momentum_encoder(im1)
            t2 = self.momentum_encoder(im2)
            # normalize
            t1 = nn.functional.normalize(t1, dim=1)
            t2 = nn.functional.normalize(t2, dim=1)
            # gather all targets
            t1 = concat_all_gather(t1)
            t2 = concat_all_gather(t2)

        # compute logits
        # Einstein sum is more intuitive
        logits1 = torch.einsum('nc,mc->nm', p1, t2) / self.T
        logits2 = torch.einsum('nc,mc->nm', p2, t1) / self.T

        N = logits1.shape[0]  # batch size per GPU
        dist = torch.distributed
        # Fall back to rank 0 when no process group has been initialized.
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        # Build the labels directly on the logits' device instead of creating
        # them on the CPU and copying them over with a hard-coded .cuda().
        labels = torch.arange(N, dtype=torch.long, device=logits1.device) + N * rank

        return logits1, logits2, labels
2021-06-17 10:59:59 +08:00
# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    Returns the per-process tensors concatenated along dim 0. When
    torch.distributed is unavailable or no process group has been
    initialized, a detached copy of `tensor` is returned, which is what a
    world size of 1 would produce.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return tensor.clone()

    # all_gather overwrites every buffer, so there is no need to fill them.
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)