# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import CosineEMA
from .base import BaseSelfSupervisor


@MODELS.register_module()
class BYOL(BaseSelfSupervisor):
    """BYOL.

    Implementation of `Bootstrap Your Own Latent: A New Approach to
    Self-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.

    Args:
        backbone (dict): Config dict for the backbone module.
        neck (dict): Config dict for the neck module, which projects deep
            features to compact feature vectors.
        head (dict): Config dict for the head module, which computes the
            loss between online and target projections.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.004.
        pretrained (str, optional): Path to a pretrained checkpoint; both
            local and remote paths are supported. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None, or if no type is specified,
            "SelfSupDataPreprocessor" will be used.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 base_momentum: float = 0.004,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.target_net = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
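        # CosineEMA anneals the update rate over training. In the paper's
        # notation the target decay follows
        #     tau_k = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2,
        # with K the total number of steps; here `momentum` plays the role
        # of 1 - tau, so base_momentum=0.004 corresponds to tau_base=0.996.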

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function during training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward computation.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        assert isinstance(inputs, list)
        img_v1 = inputs[0]
        img_v2 = inputs[1]
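        # The data pipeline provides two differently augmented views of the
        # same batch of images; each view is fed to both networks below.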
        # compute online features
        proj_online_v1 = self.neck(self.backbone(img_v1))[0]
        proj_online_v2 = self.neck(self.backbone(img_v2))[0]
        # compute target features
        with torch.no_grad():
            # update the target net
            self.target_net.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            proj_target_v1 = self.target_net(img_v1)[0]
            proj_target_v2 = self.target_net(img_v2)[0]
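        # The no_grad context also implements BYOL's stop-gradient: the
        # target projections act as fixed regression targets, and the
        # target network is updated only through the EMA step above.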

        loss_1 = self.head.loss(proj_online_v1, proj_target_v2)
        loss_2 = self.head.loss(proj_online_v2, proj_target_v1)
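        # The objective is symmetrized: each view serves once as the online
        # input and once as the target. For L2-normalized vectors,
        # ||x - y||^2 = 2 - 2 * cos(x, y), so a cosine-similarity loss
        # matches the paper's MSE objective up to scale and shift.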

        losses = dict(loss=2. * (loss_1 + loss_2))
        return losses
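

# Usage sketch (illustrative, not part of the upstream file): BYOL is built
# from a config dict via the MODELS registry. The component types below
# ('ResNet', 'NonLinearNeck', 'LatentPredictHead', 'CosineSimilarityLoss')
# follow the structure of mmpretrain's shipped BYOL configs; treat the exact
# arguments and channel sizes as assumptions and copy them from those configs.
#
#     from mmpretrain.registry import MODELS
#
#     model = MODELS.build(
#         dict(
#             type='BYOL',
#             backbone=dict(type='ResNet', depth=50),
#             neck=dict(
#                 type='NonLinearNeck',
#                 in_channels=2048,
#                 hid_channels=4096,
#                 out_channels=256),
#             head=dict(
#                 type='LatentPredictHead',
#                 predictor=dict(
#                     type='NonLinearNeck',
#                     in_channels=256,
#                     hid_channels=4096,
#                     out_channels=256),
#                 loss=dict(type='CosineSimilarityLoss'))))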