mmclassification/mmpretrain/models/heads/contrastive_head.py
Yixiao Fang 08dc8c75d3
[Refactor] Add selfsup algorithms. (#1389)
* remove basehead

* add moco series

* add byol simclr simsiam

* add ut

* update configs

* add simsiam hook

* add and refactor beit

* update ut

* add cae

* update extract_feat

* refactor cae

* add mae

* refactor data preprocessor

* update heads

* add maskfeat

* add milan

* add simmim

* add mixmim

* fix lint

* fix ut

* fix lint

* add eva

* add densecl

* add barlowtwins

* add swav

* fix lint

* update readthedocs rst

* update docs

* update

* Decrease UT memory usage

* Fix docstring

* update DALLEEncoder

* Update model docs

* refactor dalle encoder

* update docstring

* fix ut

* fix config error

* add val_cfg and test_cfg

* refactor clip generator

* fix lint

* pass check

* fix ut

* add lars

* update type of BEiT in configs

* Use MMEngine style momentum in EMA.

* apply mmpretrain solarize

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
2023-03-06 16:53:15 +08:00

51 lines
1.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class ContrastiveHead(BaseModule):
    """Head for contrastive learning.

    The contrastive loss is implemented in this head and is used in SimCLR,
    MoCo, DenseCL, etc.

    Args:
        loss (dict): Config dict for module of loss functions.
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Defaults to 0.1.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 temperature: float = 0.1,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # The loss is built from the registry config passed in ``loss``.
        self.loss_module = MODELS.build(loss)
        self.temperature = temperature

    def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
        """Compute the contrastive loss from positive/negative similarities.

        The positive and negative similarities are concatenated into a
        single ``N x (1 + k)`` logits tensor with the positive in column 0,
        scaled by ``1 / temperature``, and passed to ``self.loss_module``
        together with an all-zero label vector (i.e. column 0 is the target
        for every row).

        Args:
            pos (torch.Tensor): Nx1 positive similarity.
            neg (torch.Tensor): Nxk negative similarity.

        Returns:
            torch.Tensor: The contrastive loss.
        """
        N = pos.size(0)
        # Column 0 holds the positive pair, columns 1..k the negatives.
        # Out-of-place division keeps the op simple for autograd.
        logits = torch.cat((pos, neg), dim=1) / self.temperature
        # All-zero labels mark column 0 (the positive) as the target of each
        # row. Allocate directly on the input's device instead of creating a
        # CPU tensor and copying it over on every step.
        labels = torch.zeros((N, ), dtype=torch.long, device=pos.device)
        return self.loss_module(logits, labels)