Shawn_ c920d976e7
[Feature]: Add dino (#658)
* [Feature]: Add the scaffold

* Code camp

* Update dino.py

* add loss

* Update dino_neck.py

* data_pipeline

* Update imagenet_dino.py

* [Feature]: Add dino neck

* [Feature]: Add dino neck

* [Feature]: Add teacher temp update hook

* [Feature]: Add dino algorithm

* [Feature]: Add Transform

* [Feature]: Add init

* [Feature]: Forward DINO

* [Feature]: Add DINO

* [Fix]: Delete dino dataset

* [Feature]: Add docstring

* [Feature]: Add readme

* [Fix]: Fix reviews

* [Fix]: Fix lint

---------

Co-authored-by: YuanLiuuuuuu <3463423099@qq.com>
2023-03-28 11:07:16 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmselfsup.models import BaseModel, CosineEMA
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample


@MODELS.register_module()
class DINO(BaseModel):
"""Implementation for DINO.
This module is proposed in `DINO: Emerging Properties in Self-Supervised
Vision Transformers <https://arxiv.org/abs/2104.14294>`_.
Args:
backbone (dict): Config for backbone.
neck (dict): Config for neck.
head (dict): Config for head.
pretrained (str, optional): Path for pretrained model.
Defaults to None.
base_momentum (float, optional): Base momentum for momentum update.
Defaults to 0.99.
data_preprocessor (dict, optional): Config for data preprocessor.
Defaults to None.
init_cfg (list[dict] | dict, optional): Config for initialization.
Defaults to None.
"""
    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 pretrained: Optional[str] = None,
                 base_momentum: float = 0.99,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        # Create the momentum teacher: an EMA copy of the student
        # backbone + neck, updated with a cosine momentum schedule.
        self.teacher = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)

        # Apply weight normalization to the last layer of both projection
        # heads and fix its magnitude (weight_g) to 1, as in DINO.
        self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer)
        self.neck.last_layer.weight_g.data.fill_(1)
        self.neck.last_layer.weight_g.requires_grad = False
        self.teacher.module[1].last_layer = nn.utils.weight_norm(
            self.teacher.module[1].last_layer)
        self.teacher.module[1].last_layer.weight_g.data.fill_(1)
        self.teacher.module[1].last_layer.weight_g.requires_grad = False
    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample]) -> dict:
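        """Compute the DINO loss for a batch of multi-crop views.

        Args:
            inputs (List[torch.Tensor]): Multi-crop views of the batch. The
                first two entries are the global crops, the remaining ones
                are the local crops.
            data_samples (List[SelfSupDataSample]): Data samples from the
                dataloader (not used by this method).

        Returns:
            dict: The DINO loss under the key ``loss``.
        """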
        # the first two views are global crops, the rest are local crops
        global_crops = torch.cat(inputs[:2])
        local_crops = torch.cat(inputs[2:])

        # teacher forward: the teacher only sees the global crops
        teacher_output = self.teacher(global_crops)

        # student forward on the global crops
        student_output_global = self.backbone(global_crops)
        student_output_global = self.neck(student_output_global)

        # student forward on the local crops
        student_output_local = self.backbone(local_crops)
        student_output_local = self.neck(student_output_local)

        student_output = torch.cat(
            (student_output_global, student_output_local))

        # compute the loss between student and teacher outputs
        loss = self.head(student_output, teacher_output)

        return dict(loss=loss)
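
For readers new to the mmselfsup registry, a minimal usage sketch follows. The backbone/neck/head entries (type names and arguments) are illustrative placeholders rather than the exact configs added in this PR; only the DINO class and the MODELS registry are taken from the file above.

import torch
from mmselfsup.registry import MODELS

# Hypothetical config: the backbone/neck/head types and their arguments
# are placeholders, not the exact settings shipped with this PR.
model_cfg = dict(
    type='DINO',
    backbone=dict(type='VisionTransformer', arch='b', patch_size=16),
    neck=dict(type='DINONeck', in_channels=768, hidden_channels=2048,
              bottleneck_channels=256, out_channels=65536),
    head=dict(type='DINOHead', out_channels=65536, num_crops=10,
              student_temp=0.1, center_momentum=0.9),
    base_momentum=0.99)
model = MODELS.build(model_cfg)

# DINO-style multi-crop batch: two global crops and eight local crops.
global_crops = [torch.randn(4, 3, 224, 224) for _ in range(2)]
local_crops = [torch.randn(4, 3, 96, 96) for _ in range(8)]
losses = model.loss(global_crops + local_crops, data_samples=[])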