# moco-v3/moco/builder.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class MoCo(nn.Module):
"""
2021-06-17 17:09:43 +08:00
Build a MoCo model with: a base encoder, a momentum encoder
2021-06-17 10:59:59 +08:00
https://arxiv.org/abs/1911.05722
"""

    def __init__(self, base_encoder, with_vit, dim=256, mlp_dim=4096, T=1.0):
        """
        base_encoder: constructor of the encoder backbone
        with_vit: if True, build ViT-style encoders (with .head); else ResNet-style (with .fc)
        dim: feature dimension (default: 256)
        mlp_dim: hidden dimension in MLPs (default: 4096)
        T: softmax temperature (default: 1.0)

        The momentum coefficient m is passed per call to forward().
        """
        super(MoCo, self).__init__()

        self.T = T

        self.criterion = nn.CrossEntropyLoss()

        if with_vit:
            self._init_encoders_with_vit(base_encoder, dim, mlp_dim)
        else:
            self._init_encoders_with_resnet(base_encoder, dim, mlp_dim)

        for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data.copy_(param_b.data)  # initialize
            param_m.requires_grad = False  # not updated by gradient

    def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim):
        mlp = []
        for l in range(num_layers):
            dim1 = input_dim if l == 0 else mlp_dim
            dim2 = output_dim if l == num_layers - 1 else mlp_dim

            mlp.append(nn.Linear(dim1, dim2, bias=False))

            if l < num_layers - 1:
                mlp.append(nn.BatchNorm1d(dim2))
                mlp.append(nn.ReLU(inplace=True))
            else:
                # the final BN carries no affine parameters
                mlp.append(nn.BatchNorm1d(dim2, affine=False))

        return nn.Sequential(*mlp)
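
    # Illustration (added; not in the original file): with the arguments used for
    # the ResNet projector below, _build_mlp(2, hidden_dim, 4096, 256) produces
    #   Linear(hidden_dim, 4096, bias=False) -> BatchNorm1d -> ReLU ->
    #   Linear(4096, 256, bias=False) -> BatchNorm1d(affine=False)
    # i.e. every hidden layer gets BN + ReLU, while the output is only BN-normalized.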

    def _init_encoders_with_resnet(self, base_encoder, dim=256, mlp_dim=4096):
        # create the encoders
        # num_classes is the hidden MLP dimension
        self.base_encoder = base_encoder(num_classes=mlp_dim)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim)

        hidden_dim = self.base_encoder.fc.weight.shape[1]
        del self.base_encoder.fc, self.momentum_encoder.fc  # remove the original fc layer

        # projectors
        self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)
        self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)

        # build a 2-layer predictor
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim)
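
    # Note (added): because fc now holds the projector, self.base_encoder(x)
    # returns the dim-dimensional projected feature directly, which forward()
    # feeds straight into self.predictor.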

    def _init_encoders_with_vit(self, base_encoder, dim=256, mlp_dim=4096):
        # create the encoders
        # num_classes is the hidden MLP dimension
        self.base_encoder = base_encoder(num_classes=mlp_dim)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim)

        hidden_dim = self.base_encoder.head.weight.shape[1]
        del self.base_encoder.head, self.momentum_encoder.head  # remove the original head

        # projectors (3 layers for ViT)
        self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)
        self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)

        # build a 2-layer predictor
        self.predictor = self._build_mlp(2, dim, mlp_dim, dim)

    @torch.no_grad()
    def _update_momentum_encoder(self, m):
        """Momentum update of the momentum encoder"""
        for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
            param_m.data = param_m.data * m + param_b.data * (1. - m)
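
    # Added note: this is an exponential moving average, theta_m <- m * theta_m
    # + (1 - m) * theta_b; with m close to 1 the momentum encoder changes slowly
    # and provides stable targets for the contrastive loss.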

    def ctr_loss(self, q, k, labels):
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        # scale the loss by 2*T, following the MoCo v3 paper
        return self.criterion(logits, labels) * (2 * self.T), logits
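
    # Shape note (added): for per-GPU batch size N and world size W, q is [N, dim]
    # and the gathered k is [N*W, dim], so logits is [N, N*W]; the positive key
    # for row i sits at column i + N * rank, which is exactly the label that
    # forward() constructs below.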

    def forward(self, x1, x2, m):
        """
        Input:
            x1: first views of images
            x2: second views of images
            m: moco momentum
        Output:
            loss, logits for each view, and the target labels
        """
        # compute features
        q1 = self.predictor(self.base_encoder(x1))
        q2 = self.predictor(self.base_encoder(x2))

        with torch.no_grad():  # no gradient
            self._update_momentum_encoder(m)  # update the momentum encoder

            # compute momentum features as targets
            k1 = self.momentum_encoder(x1)
            k2 = self.momentum_encoder(x2)

        # gather all targets
        k1 = concat_all_gather(k1)
        k2 = concat_all_gather(k2)

        N = q1.shape[0]  # batch size per GPU
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()

        l1, logits1 = self.ctr_loss(q1, k2, labels)
        l2, logits2 = self.ctr_loss(q2, k1, labels)
        return l1 + l2, logits1, logits2, labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
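

if __name__ == "__main__":
    # Minimal single-GPU smoke test -- a sketch added for illustration, not part
    # of the original file. It assumes torchvision is installed and one CUDA
    # device is available; a one-process NCCL group makes
    # torch.distributed.get_rank() and concat_all_gather valid without a
    # multi-GPU launch.
    import os

    import torchvision.models as models

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    model = MoCo(models.resnet50, with_vit=False).cuda()
    x1 = torch.randn(8, 3, 224, 224, device="cuda")
    x2 = torch.randn(8, 3, 224, 224, device="cuda")

    loss, logits1, logits2, labels = model(x1, x2, m=0.99)
    print("loss:", loss.item(), "logits shape:", tuple(logits1.shape))

    torch.distributed.destroy_process_group()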