[Feature] Support dinov2 backbone (#1522)
* support dinov2 backbone
* update metafile and readme
* compatible with use_layer_scale
* update SwiGLUFFN
* add deprecation warning
* update

parent 496e098b21
commit d9e561a09d
@@ -1,10 +1,15 @@
version: 2

+# Set the version of Python and other tools you might need
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.8"
+
formats:
  - epub

python:
-  version: 3.8
  install:
    - requirements: requirements/docs.txt
    - requirements: requirements/readthedocs.txt

@@ -0,0 +1,58 @@
# DINOv2

> [DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)

<!-- [ALGORITHM] -->

## Abstract

The recent breakthroughs in natural language processing for model pretraining on large quantities of data have opened the way for similar foundation models in computer vision. These models could greatly simplify the use of images in any system by producing all-purpose visual features, i.e., features that work across image distributions and tasks without fine-tuning. This work shows that existing pretraining methods, especially self-supervised methods, can produce such features if trained on enough curated data from diverse sources. We revisit existing approaches and combine different techniques to scale our pretraining in terms of data and model size. Most of the technical contributions aim at accelerating and stabilizing the training at scale. In terms of data, we propose an automatic pipeline to build a dedicated, diverse, and curated image dataset instead of uncurated data, as typically done in the self-supervised literature. In terms of models, we train a ViT model (Dosovitskiy et al., 2020) with 1B parameters and distill it into a series of smaller models that surpass the best available all-purpose features, OpenCLIP (Ilharco et al., 2021), on most of the benchmarks at image and pixel levels.

<div align=center>
<img src="https://user-images.githubusercontent.com/36138628/234560516-b495795c-c75c-444c-a712-bb61a3de444e.png" width="70%"/>
</div>

## How to use it?

<!-- [TABS-BEGIN] -->

**Use the model**

```python
import torch
from mmpretrain import get_model

model = get_model('vit-small-p14_dinov2-pre_3rdparty', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```

<!-- [TABS-END] -->

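The configs below fix `img_size=518` with `patch_size=14`, i.e. a 37 × 37 patch grid (518 / 14 = 37), while the snippet above feeds 224 × 224 inputs (a 16 × 16 grid); the backbone resizes the pretrained position embedding on the fly with the default bicubic mode. A minimal shape check (a sketch, assuming these headless models keep the default `cls_token` output type):

```python
import torch
from mmpretrain import get_model

model = get_model('vit-small-p14_dinov2-pre_3rdparty', pretrained=True)
feats = model.extract_feat(torch.rand(1, 3, 224, 224))
# Assuming the default 'cls_token' output type, feats is a tuple holding a
# single (1, 384) tensor -- 384 is the embed_dims of the dinov2-small arch.
print(feats[0].shape)
```
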
## Models and results

### Pretrained models

| Model                                 | Params (M) | Flops (G) | Config                                         | Download                                                                                                                |
| :------------------------------------ | :--------: | :-------: | :--------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------: |
| `vit-small-p14_dinov2-pre_3rdparty`\* | 22.06      | 46.76     | [config](vit-small-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-small-p14_dinov2-pre_3rdparty_20230426-5641ca5a.pth) |
| `vit-base-p14_dinov2-pre_3rdparty`\*  | 86.58      | 152.00    | [config](vit-base-p14_dinov2-pre_headless.py)  | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth)  |
| `vit-large-p14_dinov2-pre_3rdparty`\* | 304.00     | 507.00    | [config](vit-large-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-large-p14_dinov2-pre_3rdparty_20230426-f3302d9e.pth) |
| `vit-giant-p14_dinov2-pre_3rdparty`\* | 1136.00    | 1784.00   | [config](vit-giant-p14_dinov2-pre_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-giant-p14_dinov2-pre_3rdparty_20230426-2934a630.pth) |

*Models with \* are converted from the [official repo](https://github.com/facebookresearch/dinov2). The config files of these models are only for inference. We haven't reproduced the training results.*

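These checkpoints are head-less, so fine-tuning means attaching your own head and loading the released weights into the backbone. A minimal sketch (the linear head, its class count, and the `prefix='backbone'` assumption about the checkpoint layout are illustrative, not part of this PR):

```python
# Hypothetical fine-tuning config reusing the converted checkpoint.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='dinov2-small',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmpretrain/v1.0/dinov2/'
            'vit-small-p14_dinov2-pre_3rdparty_20230426-5641ca5a.pth',
            prefix='backbone'),
    ),
    neck=None,
    head=dict(
        type='LinearClsHead',
        num_classes=1000,  # placeholder class count
        in_channels=384,   # embed_dims of the dinov2-small arch
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))
```
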
## Citation

```bibtex
@misc{oquab2023dinov2,
  title={DINOv2: Learning Robust Visual Features without Supervision},
  author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
  journal={arXiv:2304.07193},
  year={2023}
}
```

@@ -0,0 +1,73 @@
Collections:
  - Name: DINOv2
    Metadata:
      Architecture:
        - Dropout
        - GELU
        - Layer Normalization
        - Multi-Head Attention
        - Scaled Dot-Product Attention
    Paper:
      Title: 'DINOv2: Learning Robust Visual Features without Supervision'
      URL: https://arxiv.org/abs/2304.07193
    README: configs/dinov2/README.md
    Code:
      URL: null
      Version: null

Models:
  - Name: vit-small-p14_dinov2-pre_3rdparty
    Metadata:
      FLOPs: 46762000000
      Parameters: 22056000
      Training Data:
        - LVD-142M
    In Collection: DINOv2
    Results: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-small-p14_dinov2-pre_3rdparty_20230426-5641ca5a.pth
    Config: configs/dinov2/vit-small-p14_dinov2-pre_headless.py
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
      Code: https://github.com/facebookresearch/dinov2

  - Name: vit-base-p14_dinov2-pre_3rdparty
    Metadata:
      FLOPs: 152000000000
      Parameters: 86580000
      Training Data:
        - LVD-142M
    In Collection: DINOv2
    Results: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-base-p14_dinov2-pre_3rdparty_20230426-ba246503.pth
    Config: configs/dinov2/vit-base-p14_dinov2-pre_headless.py
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth
      Code: https://github.com/facebookresearch/dinov2

  - Name: vit-large-p14_dinov2-pre_3rdparty
    Metadata:
      FLOPs: 507000000000
      Parameters: 304000000
      Training Data:
        - LVD-142M
    In Collection: DINOv2
    Results: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-large-p14_dinov2-pre_3rdparty_20230426-f3302d9e.pth
    Config: configs/dinov2/vit-large-p14_dinov2-pre_headless.py
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth
      Code: https://github.com/facebookresearch/dinov2

  - Name: vit-giant-p14_dinov2-pre_3rdparty
    Metadata:
      FLOPs: 1784000000000
      Parameters: 1136000000
      Training Data:
        - LVD-142M
    In Collection: DINOv2
    Results: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/dinov2/vit-giant-p14_dinov2-pre_3rdparty_20230426-2934a630.pth
    Config: configs/dinov2/vit-giant-p14_dinov2-pre_headless.py
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth
      Code: https://github.com/facebookresearch/dinov2

@@ -0,0 +1,20 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='dinov2-giant',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
        layer_cfgs=dict(ffn_type='swiglu_fused'),
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

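Of the four sizes, only the giant config switches the FFN to the fused SwiGLU implementation via `layer_cfgs`, mirroring the official DINOv2 ViT-g; the small/base/large configs keep the default MLP FFN. The `layer_cfgs` knob can be exercised on any size, e.g. a randomly initialized small backbone (a sketch only; the released small weights do not use SwiGLU):

```python
import torch
from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(
        type='VisionTransformer',
        arch='dinov2-small',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
        # Passed to every TransformerEncoderLayer; selects SwiGLUFFNFused.
        layer_cfgs=dict(ffn_type='swiglu_fused')))

out = backbone(torch.rand(1, 3, 224, 224))
print(out[0].shape)  # (1, 384), assuming the default 'cls_token' output type
```
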
@@ -0,0 +1,20 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='large',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

@@ -0,0 +1,20 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='dinov2-small',
        img_size=518,
        patch_size=14,
        layer_scale_init_value=1e-5,
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

@@ -9,8 +9,8 @@ from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
-from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
-                     to_2tuple)
+from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer,
+                     resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone

@@ -21,6 +21,8 @@ class TransformerEncoderLayer(BaseModule):
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.

@@ -29,6 +31,7 @@
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.

@@ -41,11 +44,13 @@
                 embed_dims,
                 num_heads,
                 feedforward_channels,
+                 layer_scale_init_value=0.,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
+                 ffn_type='origin',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):

@@ -61,17 +66,27 @@
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
-            qkv_bias=qkv_bias)
+            qkv_bias=qkv_bias,
+            layer_scale_init_value=layer_scale_init_value)

        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

-        self.ffn = FFN(
-            embed_dims=embed_dims,
-            feedforward_channels=feedforward_channels,
-            num_fcs=num_fcs,
-            ffn_drop=drop_rate,
-            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
-            act_cfg=act_cfg)
+        if ffn_type == 'origin':
+            self.ffn = FFN(
+                embed_dims=embed_dims,
+                feedforward_channels=feedforward_channels,
+                num_fcs=num_fcs,
+                ffn_drop=drop_rate,
+                dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+                act_cfg=act_cfg,
+                layer_scale_init_value=layer_scale_init_value)
+        elif ffn_type == 'swiglu_fused':
+            self.ffn = SwiGLUFFNFused(
+                embed_dims=embed_dims,
+                feedforward_channels=feedforward_channels,
+                layer_scale_init_value=layer_scale_init_value)
+        else:
+            raise NotImplementedError

    @property
    def norm1(self):

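Inside the backbone, the new `ffn_type` argument keeps the original MLP (`FFN`) as the default and swaps in `SwiGLUFFNFused` when set to `'swiglu_fused'`; anything else raises `NotImplementedError`. A standalone sketch, importing the layer from the backbone module where it is defined:

```python
import torch
from mmpretrain.models.backbones.vision_transformer import TransformerEncoderLayer

layer = TransformerEncoderLayer(
    embed_dims=384,
    num_heads=6,
    feedforward_channels=1536,
    ffn_type='swiglu_fused',        # selects SwiGLUFFNFused instead of FFN
    layer_scale_init_value=1e-5)    # enables LayerScale in attention and FFN

x = torch.rand(1, 197, 384)         # (batch, tokens, embed_dims)
print(layer(x).shape)               # token shape is preserved: (1, 197, 384)
```
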
@@ -147,6 +162,8 @@ class VisionTransformer(BaseBackbone):
            -1 means not freezing any parameters. Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.

@@ -203,7 +220,7 @@
                'feedforward_channels': 192 * 4
            }),
        **dict.fromkeys(
-            ['deit-s', 'deit-small'], {
+            ['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,

@@ -216,6 +233,13 @@
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
+        **dict.fromkeys(
+            ['dinov2-g', 'dinov2-giant'], {
+                'embed_dims': 1536,
+                'num_layers': 40,
+                'num_heads': 24,
+                'feedforward_channels': 6144
+            }),
    }
    num_extra_tokens = 1  # class token
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

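A note on the new `dinov2-g` entry: `feedforward_channels=6144` is the nominal 4 × 1536 MLP width. When the giant config selects `ffn_type='swiglu_fused'`, the `SwiGLUFFNFused` module added in this PR shrinks that width by 2/3 and rounds up to a multiple of 8, which works out to a 4096-wide hidden layer (matching, as far as we can tell, the SwiGLU width of the official ViT-g/14 weights):

```python
feedforward_channels = 6144
hidden_dims = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
print(hidden_dims)  # 4096
```
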
@@ -235,6 +259,7 @@
                 with_cls_token=True,
                 frozen_stages=-1,
                 interpolate_mode='bicubic',
+                 layer_scale_init_value=0.,
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 pre_norm=False,

@@ -322,6 +347,7 @@
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
+                layer_scale_init_value=layer_scale_init_value,
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,

@@ -22,6 +22,7 @@ from .position_encoding import (ConditionalPositionEncoding,
                                build_2d_sincos_position_embedding)
from .res_layer_extra_norm import ResLayerExtraNorm
from .se_layer import SELayer
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .vector_quantizer import NormEMAVectorQuantizer

__all__ = [

@@ -69,4 +70,6 @@ __all__ = [
    'VideoDataPreprocessor',
    'CosineEMA',
    'ResLayerExtraNorm',
+    'SwiGLUFFN',
+    'SwiGLUFFNFused',
]

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
+import warnings
from functools import partial
from typing import List, Optional, Union

@@ -528,6 +529,9 @@ class MultiheadAttention(BaseModule):
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to use layer scale. Defaults to False.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

@@ -544,6 +548,7 @@
                 proj_bias=True,
                 v_shortcut=False,
                 use_layer_scale=False,
+                 layer_scale_init_value=0.,
                 init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)

@@ -568,7 +573,14 @@
        self.out_drop = build_dropout(dropout_layer)

        if use_layer_scale:
-            self.gamma1 = LayerScale(embed_dims)
+            warnings.warn('The `use_layer_scale` in `MultiheadAttention` will '
+                          'be deprecated. Please use `layer_scale_init_value` '
+                          'to control whether using layer scale or not.')
+
+        if use_layer_scale or (layer_scale_init_value > 0):
+            layer_scale_init_value = layer_scale_init_value or 1e-5
+            self.gamma1 = LayerScale(
+                embed_dims, layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma1 = nn.Identity()

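With this change, layer scale in `MultiheadAttention` is controlled by `layer_scale_init_value`; passing `use_layer_scale=True` still works, but now emits the warning above and falls back to the 1e-5 default. A quick sketch of the two spellings:

```python
import torch
from mmpretrain.models.utils import MultiheadAttention

# Preferred: a positive init value enables LayerScale on the attention output.
attn = MultiheadAttention(embed_dims=384, num_heads=6,
                          layer_scale_init_value=1e-5)

# Deprecated spelling: still enables LayerScale (with the 1e-5 default),
# but triggers the deprecation warning added in this PR.
attn_old = MultiheadAttention(embed_dims=384, num_heads=6,
                              use_layer_scale=True)

x = torch.rand(1, 197, 384)
print(attn(x).shape)  # (1, 197, 384)
```
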
@@ -1057,9 +1069,19 @@ class PromptMultiheadAttention(MultiheadAttention):
                 v_shortcut: bool = False,
                 use_layer_scale: bool = False,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
-        super().__init__(embed_dims, num_heads, input_dims, attn_drop,
-                         proj_drop, dropout_layer, qkv_bias, qk_scale,
-                         proj_bias, v_shortcut, use_layer_scale, init_cfg)
+        super().__init__(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            input_dims=input_dims,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            dropout_layer=dropout_layer,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            proj_bias=proj_bias,
+            v_shortcut=v_shortcut,
+            use_layer_scale=use_layer_scale,
+            init_cfg=init_cfg)
        # no longer need qkv
        del self.qkv

@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
import torch
import torch.nn as nn

@@ -8,6 +10,8 @@ class LayerScale(nn.Module):

    Args:
        dim (int): Dimension of input features.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 1e-5.
        inplace (bool): Whether to do the operation in-place.
            Defaults to False.
        data_format (str): The input data format, could be 'channels_last'

@@ -17,6 +21,7 @@

    def __init__(self,
                 dim: int,
+                 layer_scale_init_value: Union[float, torch.Tensor] = 1e-5,
                 inplace: bool = False,
                 data_format: str = 'channels_last'):
        super().__init__()

@@ -24,7 +29,7 @@
            "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
-        self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+        self.weight = nn.Parameter(torch.ones(dim) * layer_scale_init_value)

    def forward(self, x):
        if self.data_format == 'channels_first':

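`LayerScale` now takes its init value as an argument instead of hard-coding 1e-5. It is simply a learnable per-channel rescaling, as the small sketch below shows:

```python
import torch
from mmpretrain.models.utils import LayerScale

ls = LayerScale(dim=4, layer_scale_init_value=0.1)
x = torch.ones(1, 2, 4)
# At init every channel is scaled by 0.1; the weights are learnable afterwards.
print(ls(x))
```
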
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout

from .layer_scale import LayerScale
from .norm import build_norm_layer


class SwiGLUFFN(nn.Module):
    """SwiGLU FFN layer.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
        dropout_layer: Optional[dict] = None,
        norm_cfg: Optional[dict] = None,
        add_identity: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dims = embed_dims
        self.out_dims = out_dims or embed_dims
        hidden_dims = feedforward_channels or embed_dims

        self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

        if norm_cfg is not None:
            self.norm = build_norm_layer(norm_cfg, hidden_dims)
        else:
            self.norm = nn.Identity()

        self.w3 = nn.Linear(hidden_dims, self.out_dims, bias=bias)

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(
                dim=embed_dims, layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self,
                x: torch.Tensor,
                identity: Optional[torch.Tensor] = None) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        hidden = self.norm(hidden)
        out = self.w3(hidden)
        out = self.gamma2(out)
        out = self.dropout_layer(out)

        if self.out_dims != self.embed_dims or not self.add_identity:
            # due to the dimension mismatch, or the user choosing not to
            # apply the residual connection
            return out

        if identity is None:
            identity = x
        return identity + out


class SwiGLUFFNFused(SwiGLUFFN):
    """SwiGLU FFN layer with fusing.

    Modified from https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py
    """  # noqa

    def __init__(
        self,
        embed_dims: int,
        feedforward_channels: Optional[int] = None,
        out_dims: Optional[int] = None,
        layer_scale_init_value: float = 0.,
        bias: bool = True,
    ) -> None:
        out_dims = out_dims or embed_dims
        feedforward_channels = feedforward_channels or embed_dims
        feedforward_channels = (int(feedforward_channels * 2 / 3) + 7) // 8 * 8
        super().__init__(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            out_dims=out_dims,
            layer_scale_init_value=layer_scale_init_value,
            bias=bias,
        )

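For readers unfamiliar with SwiGLU, the packed `w12` projection computes the gate and the value branch in one matmul, and the output is `w3(silu(x1) * x2)`. A small self-check against the module above (using `add_identity=False` so the residual add does not obscure the comparison):

```python
import torch
import torch.nn.functional as F

from mmpretrain.models.utils import SwiGLUFFN

torch.manual_seed(0)
ffn = SwiGLUFFN(embed_dims=4, add_identity=False)
x = torch.randn(1, 8, 4)

# Manual SwiGLU: split the packed projection, gate with SiLU, project back.
x1, x2 = ffn.w12(x).chunk(2, dim=-1)
manual = ffn.w3(F.silu(x1) * x2)

assert torch.allclose(ffn(x), manual)
```
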
@@ -69,3 +69,4 @@ Import:
  - configs/riformer/metafile.yml
  - configs/sam/metafile.yml
  - configs/glip/metafile.yml
+  - configs/dinov2/metafile.yml

@@ -30,6 +30,10 @@ test_list = [
        backbone=mmpretrain.models.ViTSAM,
        forward=False,
        backward=False),
+    Cfg(name='vit-base-p14_dinov2-pre_3rdparty',
+        backbone=mmpretrain.models.VisionTransformer,
+        forward=False,
+        backward=False),
]

@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
import torch.nn as nn

from mmpretrain.models.utils import LayerScale, SwiGLUFFN, SwiGLUFFNFused


class TestSwiGLUFFN(TestCase):

    def test_init(self):
        swiglu = SwiGLUFFN(embed_dims=4)
        assert swiglu.w12.weight.shape == torch.ones((8, 4)).shape
        assert swiglu.w3.weight.shape == torch.ones((4, 4)).shape
        assert isinstance(swiglu.gamma2, nn.Identity)

        swiglu = SwiGLUFFN(embed_dims=4, layer_scale_init_value=0.1)
        assert isinstance(swiglu.gamma2, LayerScale)

    def test_forward(self):
        swiglu = SwiGLUFFN(embed_dims=4)
        x = torch.randn((1, 8, 4))
        out = swiglu(x)
        self.assertEqual(out.size(), x.size())

        swiglu = SwiGLUFFN(embed_dims=4, out_dims=12)
        x = torch.randn((1, 8, 4))
        out = swiglu(x)
        self.assertEqual(tuple(out.size()), (1, 8, 12))


class TestSwiGLUFFNFused(TestCase):

    def test_init(self):
        swiglu = SwiGLUFFNFused(embed_dims=4)
        assert swiglu.w12.weight.shape == torch.ones((16, 4)).shape
        assert swiglu.w3.weight.shape == torch.ones((4, 8)).shape
        assert isinstance(swiglu.gamma2, nn.Identity)

        swiglu = SwiGLUFFNFused(embed_dims=4, layer_scale_init_value=0.1)
        assert isinstance(swiglu.gamma2, LayerScale)

    def test_forward(self):
        swiglu = SwiGLUFFNFused(embed_dims=4)
        x = torch.randn((1, 8, 4))
        out = swiglu(x)
        self.assertEqual(out.size(), x.size())

        swiglu = SwiGLUFFNFused(embed_dims=4, out_dims=12)
        x = torch.randn((1, 8, 4))
        out = swiglu(x)
        self.assertEqual(tuple(out.size()), (1, 8, 12))