# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.models import BaseSelfSupervisor, CosineEMA
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class DINO(BaseSelfSupervisor):
    """Implementation for DINO.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        backbone (dict): Config for backbone.
        neck (dict): Config for neck.
        head (dict): Config for head.
        pretrained (str, optional): Path for pretrained model.
            Defaults to None.
        base_momentum (float, optional): Base momentum for momentum update.
            Defaults to 0.99.
        data_preprocessor (dict, optional): Config for data preprocessor.
            Defaults to None.
        init_cfg (list[dict] | dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 pretrained: Optional[str] = None,
                 base_momentum: float = 0.99,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model: the teacher is an EMA copy of the
        # student backbone + neck, updated with a cosine momentum schedule
        self.teacher = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
        # weight normalization layer: wrap the last projection layer of both
        # the student neck and the teacher neck, and freeze the norm scale
        # (``weight_g``) at 1 so that only the weight direction is trained
        self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer)
        self.neck.last_layer.weight_g.data.fill_(1)
        self.neck.last_layer.weight_g.requires_grad = False
        self.teacher.module[1].last_layer = nn.utils.weight_norm(
            self.teacher.module[1].last_layer)
        self.teacher.module[1].last_layer.weight_g.data.fill_(1)
        self.teacher.module[1].last_layer.weight_g.requires_grad = False

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Compute the DINO loss from multi-crop inputs.

        The first two entries of ``inputs`` are the global crops and the
        remaining entries are the local crops. The teacher only sees the
        global crops, while the student sees all crops.
        """
        global_crops = torch.cat(inputs[:2])
        local_crops = torch.cat(inputs[2:])
        # teacher forward
        teacher_output = self.teacher(global_crops)

        # student forward global
        student_output_global = self.backbone(global_crops)
        student_output_global = self.neck(student_output_global)

        # student forward local
        student_output_local = self.backbone(local_crops)
        student_output_local = self.neck(student_output_local)

        student_output = torch.cat(
            (student_output_global, student_output_local))

        # compute loss
        loss = self.head(student_output, teacher_output)

        return dict(loss=loss)
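

# --- Usage sketch (illustrative; not part of the upstream file) ---
# A hypothetical config wiring for this module, matching the ``Args`` in the
# class docstring. The neck and head ``type`` names below ('DINONeck',
# 'DINOHead') are assumptions based on this project's companion modules, and
# the elided arguments would need to be filled in from the project's configs.
#
#     model = dict(
#         type='DINO',
#         backbone=dict(type='VisionTransformer', arch='b', patch_size=16),
#         neck=dict(type='DINONeck', ...),
#         head=dict(type='DINOHead', ...),
#         base_momentum=0.99)
#
# An instance could then be created with ``MODELS.build(model)``.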
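

# --- Multi-crop input layout (illustrative sketch; assumptions noted) ---
# ``loss`` expects ``inputs`` to be a sequence of crop batches: the first two
# entries are global crops and the rest are local crops. The crop counts and
# sizes below (2 global crops at 224 px, 8 local crops at 96 px) are common
# DINO choices, not values enforced by this module.
if __name__ == '__main__':
    batch_size = 4
    inputs = ([torch.randn(batch_size, 3, 224, 224) for _ in range(2)] +
              [torch.randn(batch_size, 3, 96, 96) for _ in range(8)])

    # Mirror the first two statements of ``loss``: stack all global crops
    # into one batch and all local crops into another.
    global_crops = torch.cat(inputs[:2])
    local_crops = torch.cat(inputs[2:])
    assert global_crops.shape == (2 * batch_size, 3, 224, 224)
    assert local_crops.shape == (8 * batch_size, 3, 96, 96)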