138 lines
5.0 KiB
Python
138 lines
5.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.dist import all_gather
|
|
from mmengine.model import ExponentialMovingAverage
|
|
|
|
from mmpretrain.registry import MODELS
|
|
from mmpretrain.structures import DataSample
|
|
from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp
|
|
from .base import BaseSelfSupervisor
|
|
|
|
|
|
@MODELS.register_module()
class MoCo(BaseSelfSupervisor):
    """MoCo.

    Implementation of `Momentum Contrast for Unsupervised Visual
    Representation Learning <https://arxiv.org/abs/1911.05722>`_.
    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact feature
            vectors.
        head (dict): Config dict for module of head functions.
        queue_len (int): Number of negative keys maintained in the
            queue. Defaults to 65536.
        feat_dim (int): Dimension of compact feature vectors.
            Defaults to 128.
        momentum (float): Momentum coefficient for the momentum-updated
            encoder, passed to mmengine's ``ExponentialMovingAverage``.
            Defaults to 0.001.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no type is specified, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 queue_len: int = 65536,
                 feat_dim: int = 128,
                 momentum: float = 0.001,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model (key encoder) from the online backbone + neck
        self.encoder_k = ExponentialMovingAverage(
            nn.Sequential(self.backbone, self.neck), momentum)

        # create the queue of negative keys, stored column-wise as
        # (feat_dim, queue_len) and L2-normalized along the feature axis
        self.queue_len = queue_len
        self.register_buffer(
            'queue',
            nn.functional.normalize(torch.randn(feat_dim, queue_len), dim=0))
        # index of the oldest entry, i.e. where the next keys are written
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
        """Update the queue of negative keys with the current batch.

        Keys from all processes are gathered and written into ``self.queue``
        as a ring buffer starting at ``self.queue_ptr``, overwriting the
        oldest entries. Writes wrap around the end of the queue, so the
        gathered batch size does not have to divide ``queue_len``. If the
        gathered batch is larger than the queue, only its last
        ``queue_len`` keys are kept, exactly as sequential enqueueing would
        leave them.

        Args:
            keys (torch.Tensor): Key features of the local process, one row
                per sample (rows are written as queue columns).
        """
        # gather keys before updating queue
        keys = torch.cat(all_gather(keys), dim=0)
        batch_size = keys.shape[0]
        if batch_size == 0:
            return

        ptr = int(self.queue_ptr)
        # pointer position after the whole batch has been enqueued
        new_ptr = (ptr + batch_size) % self.queue_len

        if batch_size >= self.queue_len:
            # Older keys would be overwritten within this batch anyway; keep
            # only the newest ``queue_len``. Sequential enqueueing places the
            # first of them at ``new_ptr``.
            keys = keys[-self.queue_len:]
            start = new_ptr
        else:
            start = ptr

        num_keys = keys.shape[0]
        # Slice assignment (not advanced indexing) keeps copy semantics,
        # including dtype casting into the queue buffer.
        head_len = min(num_keys, self.queue_len - start)
        self.queue[:, start:start + head_len] = keys[:head_len].transpose(0, 1)
        if head_len < num_keys:
            # wrap around to the beginning of the queue
            self.queue[:, :num_keys - head_len] = \
                keys[head_len:].transpose(0, 1)

        self.queue_ptr[0] = new_ptr

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                encoded by the online (query) encoder and ``inputs[1]`` by
                the momentum (key) encoder.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not used by this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components with
            the single key ``'loss'``.
        """
        assert isinstance(inputs, list)
        im_q = inputs[0]
        im_k = inputs[1]
        # compute query features from encoder_q (first output of the neck)
        q = self.neck(self.backbone(im_q))[0]  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # update the key encoder from the current online weights
            self.encoder_k.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            # shuffle for making use of BN
            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)[0]  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        # The queue is cloned so the tensor used here is not affected by the
        # in-place queue update in ``_dequeue_and_enqueue`` below.
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        loss = self.head.loss(l_pos, l_neg)
        # update the queue
        self._dequeue_and_enqueue(k)

        losses = dict(loss=loss)
        return losses
|