# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmengine.dist import all_reduce, get_world_size
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS

@MODELS.register_module()
class LatentPredictHead(BaseModule):
    """Head for latent feature prediction.

    This head builds a predictor, which can be any registered neck
    component. For example, BYOL and SimSiam call this head and build
    ``NonLinearNeck``. It also implements the similarity loss between the
    two forward features.

    Args:
        loss (dict): Config dict for the loss.
        predictor (dict): Config dict for the predictor.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 predictor: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)
        self.predictor = MODELS.build(predictor)

    def loss(self, input: torch.Tensor,
             target: torch.Tensor) -> torch.Tensor:
        """Compute the latent prediction loss.

        Args:
            input (torch.Tensor): NxC input features.
            target (torch.Tensor): NxC target features.

        Returns:
            torch.Tensor: The latent prediction loss.
        """
        pred = self.predictor([input])[0]
        # Stop gradients so the target branch is not updated by this loss.
        target = target.detach()

        loss = self.loss_module(pred, target)

        return loss

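# A minimal usage sketch (added for illustration; not part of the original
# file). It assumes the registered ``CosineSimilarityLoss`` and
# ``NonLinearNeck`` components used by the BYOL/SimSiam configs; the channel
# sizes below are illustrative, not taken from a specific config.
def _latent_predict_head_example() -> torch.Tensor:
    head = MODELS.build(
        dict(
            type='LatentPredictHead',
            loss=dict(type='CosineSimilarityLoss'),
            predictor=dict(
                type='NonLinearNeck',
                in_channels=256,
                hid_channels=4096,
                out_channels=256,
                with_avg_pool=False)))
    online = torch.randn(4, 256)  # NxC features from the online branch
    target = torch.randn(4, 256)  # NxC features from the momentum branch
    return head.loss(online, target)
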
@MODELS.register_module()
class LatentCrossCorrelationHead(BaseModule):
    """Head for latent feature cross correlation.

    Part of the code is borrowed from `script
    <https://github.com/facebookresearch/barlowtwins/blob/main/main.py>`_.

    Args:
        in_channels (int): Number of input channels.
        loss (dict): Config dict for the loss.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.world_size = get_world_size()
        self.bn = nn.BatchNorm1d(in_channels, affine=False)
        self.loss_module = MODELS.build(loss)

    def loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the cross correlation loss.

        Args:
            input (torch.Tensor): NxC input features.
            target (torch.Tensor): NxC target features.

        Returns:
            torch.Tensor: The cross correlation loss.
        """
        # CxC cross-correlation matrix between the batch-normalized features,
        # divided by the per-rank batch size times the world size.
        cross_correlation_matrix = self.bn(input).T @ self.bn(target)
        cross_correlation_matrix.div_(input.size(0) * self.world_size)

        # Sum the per-rank matrices; together with the division above this
        # averages the matrix over the global batch.
        all_reduce(cross_correlation_matrix)

        loss = self.loss_module(cross_correlation_matrix)
        return loss

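# A minimal usage sketch (added for illustration; not part of the original
# file). It assumes the registered ``CrossCorrelationLoss`` used by the
# BarlowTwins config; the feature size is illustrative. Note that
# ``nn.BatchNorm1d`` requires a batch size greater than 1 in training mode.
def _latent_cross_correlation_head_example() -> torch.Tensor:
    head = MODELS.build(
        dict(
            type='LatentCrossCorrelationHead',
            in_channels=32,
            loss=dict(type='CrossCorrelationLoss')))
    z1 = torch.randn(4, 32)  # NxC projections of the first view
    z2 = torch.randn(4, 32)  # NxC projections of the second view
    return head.loss(z1, z2)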