Yixiao Fang 08dc8c75d3
[Refactor] Add selfsup algorithms. (#1389)
* remove basehead

* add moco series

* add byol simclr simsiam

* add ut

* update configs

* add simsiam hook

* add and refactor beit

* update ut

* add cae

* update extract_feat

* refactor cae

* add mae

* refactor data preprocessor

* update heads

* add maskfeat

* add milan

* add simmim

* add mixmim

* fix lint

* fix ut

* fix lint

* add eva

* add densecl

* add barlowtwins

* add swav

* fix lint

* update readthedocs rst

* update docs

* update

* Decrease UT memory usage

* Fix docstring

* update DALLEEncoder

* Update model docs

* refactor dalle encoder

* update docstring

* fix ut

* fix config error

* add val_cfg and test_cfg

* refactor clip generator

* fix lint

* pass check

* fix ut

* add lars

* update type of BEiT in configs

* Use MMEngine style momentum in EMA.

* apply mmpretrain solarize

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-06 16:53:15 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import CosineEMA
from .base import BaseSelfSupervisor


@MODELS.register_module()
class BYOL(BaseSelfSupervisor):
    """BYOL.

    Implementation of `Bootstrap Your Own Latent: A New Approach to
    Self-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.

    Args:
        backbone (dict): Config dict for the backbone module.
        neck (dict): Config dict for the neck module, which maps deep
            features to compact feature vectors.
        head (dict): Config dict for the head module.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.004.
        pretrained (str, optional): The pretrained checkpoint path; both
            local and remote paths are supported. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None, or if no type is specified, it defaults to
            "SelfSupDataPreprocessor".
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 base_momentum: float = 0.004,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create the momentum (target) network as an EMA copy of the online
        # backbone + neck
        self.target_net = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
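
    # Note on the momentum convention (an inference from the commit message
    # "Use MMEngine style momentum in EMA" above, not stated in this file):
    # CosineEMA follows MMEngine's parameterization,
    # averaged = (1 - m) * averaged + m * new, so the default
    # base_momentum=0.004 corresponds to tau = 0.996 in the BYOL paper and
    # is annealed towards 0 (i.e. tau -> 1) over training.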

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        assert isinstance(inputs, list)
        img_v1 = inputs[0]
        img_v2 = inputs[1]

        # compute online features
        proj_online_v1 = self.neck(self.backbone(img_v1))[0]
        proj_online_v2 = self.neck(self.backbone(img_v2))[0]
        # compute target features
        with torch.no_grad():
            # update the target net
            self.target_net.update_parameters(
                nn.Sequential(self.backbone, self.neck))
            proj_target_v1 = self.target_net(img_v1)[0]
            proj_target_v2 = self.target_net(img_v2)[0]
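
        # Symmetrized objective: the online projection of each view is
        # regressed onto the target projection of the other view, in both
        # directions, so the two augmented views play interchangeable roles.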
        loss_1 = self.head.loss(proj_online_v1, proj_target_v2)
        loss_2 = self.head.loss(proj_online_v2, proj_target_v1)

        losses = dict(loss=2. * (loss_1 + loss_2))
        return losses
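

# ---------------------------------------------------------------------------
# Usage sketch (appended for illustration; not part of the upstream module).
# Saved as a standalone script, this would build BYOL from config dicts via
# the registry and run one loss step on random data. The backbone/neck/head
# settings below are assumptions modelled on mmpretrain's BYOL ResNet-50
# configs, not values confirmed by this file; consult the shipped configs
# for the exact settings.
import torch

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample

model = MODELS.build(
    dict(
        type='BYOL',
        base_momentum=0.004,  # MMEngine-style momentum, i.e. 1 - tau
        backbone=dict(type='ResNet', depth=50),
        neck=dict(
            type='NonLinearNeck',
            in_channels=2048,
            hid_channels=4096,
            out_channels=256,
            num_layers=2,
            norm_cfg=dict(type='BN'),  # plain BN so the demo runs on CPU
            with_avg_pool=True),
        head=dict(
            type='LatentPredictHead',
            predictor=dict(
                type='NonLinearNeck',
                in_channels=256,
                hid_channels=4096,
                out_channels=256,
                num_layers=2,
                norm_cfg=dict(type='BN'),
                with_avg_pool=False),
            loss=dict(type='CosineSimilarityLoss'))))

# Two augmented views of the same image batch; random tensors stand in for
# real augmented images here.
views = [torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)]
losses = model.loss(views, data_samples=[DataSample(), DataSample()])
print(losses['loss'])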