# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV2Head(BaseModule):
"""Head for BEiT v2 Pre-training.
|
|
|
|
Compute the logits and the cross entropy loss.
|
|
|
|
Args:
|
|
embed_dims (int): The dimension of embedding.
|
|
num_embed (int): The number of classification types.
|
|
loss (dict): The config of loss.
|
|
init_cfg (dict or List[dict], optional): Initialization config dict.
|
|
Defaults to None.
|
|
"""
|
|
|
|
    def __init__(
        self,
        embed_dims: int,
        num_embed: int,
        loss: dict,
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        # Shared linear head mapping features to visual-token logits.
        self.cls_head = nn.Linear(embed_dims, num_embed)
        # Loss module (e.g. cross entropy) built from the registry config.
        self.loss_module = MODELS.build(loss)

    def loss(
        self, feats: torch.Tensor, feats_cls_pt: torch.Tensor,
        target: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss.

        Args:
            feats (torch.Tensor): Features from backbone.
            feats_cls_pt (torch.Tensor): Features from the class-token
                pretraining branch.
            target (torch.Tensor): Target generated by target_generator.
            mask (torch.Tensor): Generated mask for pretraining.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The cross entropy loss on
            the patch features and on the class-token branch features,
            respectively.
        """
        # Keep only the masked positions; targets are defined there only.
        mask = mask.flatten(1).to(torch.bool)
        target = target[mask]

        # The classification head is shared between both feature streams.
        logits = self.cls_head(feats[mask])
        logits_cls_pt = self.cls_head(feats_cls_pt[mask])

        loss_1 = self.loss_module(logits, target)
        loss_2 = self.loss_module(logits_cls_pt, target)
        return loss_1, loss_2
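

# What follows is an illustrative usage sketch, not part of the upstream
# module. It shows one plausible way to build this head from a config dict;
# the values mirror the typical BEiT v2 pre-training setup (ViT-base
# embedding width, an 8192-entry visual-token codebook, and mmpretrain's
# registered `CrossEntropyLoss`), but they are assumptions, not a verified
# snippet from the configs.
if __name__ == '__main__':
    head = MODELS.build(
        dict(
            type='BEiTV2Head',
            embed_dims=768,  # ViT-base embedding dimension
            num_embed=8192,  # number of visual tokens in the codebook
            loss=dict(type='CrossEntropyLoss')))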
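    # Continuing the sketch above: exercise `loss` with dummy tensors,
    # assuming a batch of 2 and a 14x14 patch grid (196 tokens, class token
    # already stripped). Real shapes come from the BEiT v2 backbone and
    # target generator, so treat these purely as shape placeholders.
    B, N, C = 2, 196, 768
    feats = torch.rand(B, N, C)
    feats_cls_pt = torch.rand(B, N, C)
    target = torch.randint(0, 8192, (B, N))
    mask = torch.rand(B, 14, 14) > 0.5  # random boolean patch mask
    loss_patch, loss_cls_pt = head.loss(feats, feats_cls_pt, target, mask)
    print(loss_patch.item(), loss_cls_pt.item())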