Add new backbones trained with registers (#282)
Add new backbones (and matching linear classification heads) trained with 4 registers following [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588).
@@ -1,12 +1,16 @@
# Model Card for DINOv2-S/B/L/g

These are Vision Transformer models trained following the method described in the paper:
These are Vision Transformer models trained following the method described in the papers:
"DINOv2: Learning Robust Visual Features without Supervision"
and
"Vision Transformers Need Registers".

We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g.
We provide 8 models:
- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, without registers.
- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, with registers.

## Model Details
The model takes an image as input and returns a class token and patch tokens.
The model takes an image as input and returns a class token and patch tokens, and optionally 4 register tokens.

The embedding dimension is:
- 384 for ViT-S.
@@ -14,9 +18,9 @@ The embedding dimension is:
- 1024 for ViT-L.
- 1536 for ViT-g.

The models follow a Transformer architecture, with a patch size of 14.
The models follow a Transformer architecture, with a patch size of 14. In the case of registers, we add 4 register tokens, learned during training, to the input sequence after the patch embedding.

For a 224x224 image, this results in 1 class token + 256 patch tokens.
For a 224x224 image, this results in 1 class token + 256 patch tokens, and optionally 4 register tokens.

The models can accept larger images provided the image shapes are multiples of the patch size (14).
If this condition is not verified, the model will crop to the closest smaller multiple of the patch size.
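As a quick sanity check of the token layout described above (a minimal sketch; it assumes the `dinov2_vits14_reg` torch.hub entry point introduced in this change, the `forward_features` output dictionary shown later in this diff, and network access to download the weights):

```python
import torch

# ViT-S/14 with 4 registers: for a 224x224 input we expect
# 1 class token + 4 register tokens + (224/14)^2 = 256 patch tokens.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
model.eval()

x = torch.zeros(1, 3, 224, 224)
with torch.no_grad():
    out = model.forward_features(x)

print(out["x_norm_clstoken"].shape)     # torch.Size([1, 384])
print(out["x_norm_regtokens"].shape)    # torch.Size([1, 4, 384])
print(out["x_norm_patchtokens"].shape)  # torch.Size([1, 256, 384])
```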
@@ -63,10 +67,18 @@ Use the code below to get started with the model.

```python
import torch

# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
```
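If you are unsure which entry points your checkout exposes, torch.hub can list and describe them (requires network access or a cached copy of the repository):

```python
import torch

# All hub entry points defined in the repository's hubconf.py;
# the *_reg variants appear once this change is included.
print(torch.hub.list('facebookresearch/dinov2'))

# Docstring of a single entry point.
print(torch.hub.help('facebookresearch/dinov2', 'dinov2_vits14_reg'))
```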
## Training Details
@@ -92,11 +104,11 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

## Evaluation

We refer users to the associated paper for the evaluation protocols.
We refer users to the associated papers for the evaluation protocols.

<table>
<tr>
<th>model</th>
<th colspan="2"></th>
<th colspan="3">ImageNet-1k</th>
<th>NYU-Depth v2</th>
<th>SUN-RGBD</th>
@@ -105,7 +117,8 @@ We refer users to the associated paper for the evaluation protocols.
<th>Oxford-H</th>
</tr>
<tr>
<th rowspan="2">task</th>
<th rowspan="2">model</th>
<th rowspan="2">with <br /> registers</th>
<th>classif. (acc)</th>
<th>classif. (acc)</th>
<th>classif. V2 (acc)</th>
@@ -128,6 +141,7 @@ We refer users to the associated paper for the evaluation protocols.
</tr>
<tr>
<td>ViT-S/14</td>
<td align="center">:x:</td>
<td align="right">79.0%</td>
<td align="right">81.1%</td>
<td align="right">70.8%</td>
@@ -137,8 +151,21 @@ We refer users to the associated paper for the evaluation protocols.
<td align="right">69.5%</td>
<td align="right">43.2</td>
</tr>
<tr>
<td>ViT-S/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">79.1%</td>
<td align="right">80.9%</td>
<td align="right">71.0%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">67.6%</td>
<td align="right">39.5</td>
</tr>
<tr>
<td>ViT-B/14</td>
<td align="center">:x:</td>
<td align="right">82.1%</td>
<td align="right">84.5%</td>
<td align="right">74.9%</td>
@@ -147,9 +174,21 @@ We refer users to the associated paper for the evaluation protocols.
<td align="right">51.3</td>
<td align="right">76.3%</td>
<td align="right">49.5</td>
</tr>
<tr>
<td>ViT-B/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">82.0%</td>
<td align="right">84.6%</td>
<td align="right">75.6%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">73.8%</td>
<td align="right">51.0</td>
</tr>
<tr>
<td>ViT-L/14</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.3%</td>
<td align="right">77.6%</td>
@@ -159,8 +198,21 @@ We refer users to the associated paper for the evaluation protocols.
<td align="right">79.8%</td>
<td align="right">54.0</td>
</tr>
<tr>
<td>ViT-L/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.8%</td>
<td align="right">86.7%</td>
<td align="right">78.5%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">80.9%</td>
<td align="right">55.7</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.5%</td>
<td align="right">78.4%</td>
@@ -170,6 +222,19 @@ We refer users to the associated paper for the evaluation protocols.
<td align="right">81.6%</td>
<td align="right">52.3</td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.7%</td>
<td align="right">87.1%</td>
<td align="right">78.8%</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">N/A</td>
<td align="right">81.5%</td>
<td align="right">58.2</td>
</tr>
</table>

## Environmental Impact
@@ -198,4 +263,10 @@ xFormers 0.0.18
  journal={arXiv:2304.07193},
  year={2023}
}
@misc{darcet2023vitneedreg,
  title={Vision Transformers Need Registers},
  author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
  journal={arXiv:2309.16588},
  year={2023}
}
```
README.md
@@ -1,3 +1,5 @@
:new: [2023-10-26] *Added DINOv2 backbones with registers.*

# DINOv2: Learning Robust Visual Features without Supervision

**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
@@ -31,6 +33,7 @@ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b4
<tr>
<th>model</th>
<th># of<br />params</th>
<th>with<br />registers</th>
<th>ImageNet<br />k-NN</th>
<th>ImageNet<br />linear</th>
<th>download</th>
@@ -40,31 +43,67 @@ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b4
<tr>
<td>ViT-S/14 distilled</td>
<td align="right">21 M</td>
<td align="center">:x:</td>
<td align="right">79.0%</td>
<td align="right">81.1%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-S/14 distilled</td>
<td align="right">21 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">79.1%</td>
<td align="right">80.9%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="right">86 M</td>
<td align="center">:x:</td>
<td align="right">82.1%</td>
<td align="right">84.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="right">86 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">82.0%</td>
<td align="right">84.6%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="right">300 M</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.3%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="right">300 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.8%</td>
<td align="right">86.7%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="right">1,100 M</td>
<td align="center">:x:</td>
<td align="right">83.5%</td>
<td align="right">86.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="right">1,100 M</td>
<td align="center">:white_check_mark:</td>
<td align="right">83.7%</td>
<td align="right">87.1%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth">backbone only</a></td>
</tr>
</tbody>
</table>
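The checkpoints linked above are plain backbone state dicts, so they can also be loaded without the hub wrapper. A sketch (it assumes a local clone so that `dinov2.models.vision_transformer` is importable; the keyword values mirror the hub entry points and the register eval configs further down in this change):

```python
import torch
from dinov2.models.vision_transformer import vit_small

# ViT-S/14 with 4 registers, matching dinov2_vits14_reg4_pretrain.pth.
model = vit_small(
    patch_size=14,
    img_size=518,                 # position embeddings are set up for 518x518 crops
    init_values=1.0,
    block_chunks=0,
    num_register_tokens=4,
    interpolate_antialias=True,
    interpolate_offset=0.0,
)

url = "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
```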
@@ -77,10 +116,17 @@ A corresponding [model card](MODEL_CARD.md) is included in the repository.
```python
import torch

# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
```
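A small usage sketch for feature extraction (the resize and normalization values below are the standard ImageNet preprocessing commonly used with these backbones, not something this README prescribes):

```python
import torch
import torchvision.transforms as T
from PIL import Image

transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),  # 224 is divisible by the patch size 14
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
model.eval()

img = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    embedding = model(img)  # (1, 768) class-token embedding for ViT-B/14

print(embedding.shape)
```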
### Pretrained heads - Image classification
@@ -89,6 +135,7 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
<thead>
<tr>
<th rowspan="2">backbone</th>
<th rowspan="2">with<br />registers</th>
<th>download</th>
</tr>
<tr>
@@ -98,29 +145,62 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
<tbody>
<tr>
<td>ViT-S/14 distilled</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-S/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear4_head.pth">4 layers</a>)
</td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear4_head.pth">4 layers</a>)
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear4_head.pth">4 layers</a>)
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear4_head.pth">4 layers</a>)
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear4_head.pth">4 layers</a>)
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:x:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear4_head.pth">4 layers</a>)
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:white_check_mark:</td>
<td>
linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">1 layer</a>,
<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear4_head.pth">4 layers</a>)
</tr>
</tbody>
</table>
@@ -129,10 +209,17 @@ The (full) classifier models can be loaded via PyTorch Hub:
```python
import torch

# DINOv2
dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')

# DINOv2 with registers
dinov2_vits14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc')
dinov2_vitb14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc')
dinov2_vitl14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc')
dinov2_vitg14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
```
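These `*_lc` entry points bundle a backbone with the ImageNet-1k linear head, so calling them yields 1000-way logits. A minimal shape check (random input, no image pipeline):

```python
import torch

classifier = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc')
classifier.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = classifier(x)

print(logits.shape)  # torch.Size([1, 1000])
```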
### Pretrained heads - Depth estimation
@@ -429,29 +516,58 @@ We release the weights from evaluating the different models:
<table style="margin: auto">
<tr>
<th>model</th>
<th>with<br />registers</th>
<th>ImageNet<br />top-1</th>
<th>linear evaluation</th>
</tr>
<tr>
<td>ViT-S/14 distilled</td>
<td align="center">:x:</td>
<td align="right">81.1%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-S/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td align="right">80.8%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="center">:x:</td>
<td align="right">84.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-B/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td align="right">84.4%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="center">:x:</td>
<td align="right">86.3%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-L/14 distilled</td>
<td align="center">:white_check_mark:</td>
<td align="right">86.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:x:</td>
<td align="right">86.5%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td>
</tr>
<tr>
<td>ViT-g/14</td>
<td align="center">:white_check_mark:</td>
<td align="right">87.0%</td>
<td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">linear head weights</a></td>
</tr>
</table>

The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k:
@@ -493,3 +609,12 @@ If you find this repository useful, please consider giving a star :star: and cit
  year={2023}
}
```

```
@misc{darcet2023vitneedreg,
  title={Vision Transformers Need Registers},
  author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr},
  journal={arXiv:2309.16588},
  year={2023}
}
```
@@ -0,0 +1,9 @@
student:
  arch: vit_base
  patch_size: 14
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,10 @@
student:
  arch: vit_giant2
  patch_size: 14
  ffn_layer: swiglufused
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,9 @@
student:
  arch: vit_large
  patch_size: 14
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -0,0 +1,9 @@
student:
  arch: vit_small
  patch_size: 14
  num_register_tokens: 4
  interpolate_antialias: true
  interpolate_offset: 0.0
crops:
  global_crops_size: 518 # this is to set up the position embeddings properly
  local_crops_size: 98
@@ -80,6 +80,9 @@ student:
  qkv_bias: true
  proj_bias: true
  ffn_bias: true
  num_register_tokens: 0
  interpolate_antialias: false
  interpolate_offset: 0.1
teacher:
  momentum_teacher: 0.992
  final_momentum_teacher: 1
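The new `*_reg4` eval configs above only override these defaults. A sketch of the resulting merge semantics, assuming OmegaConf-style merging as used for the repo's YAML configs:

```python
from omegaconf import OmegaConf

# num_register_tokens is 0 by default and 4 in the register configs above.
defaults = OmegaConf.create(
    {"student": {"num_register_tokens": 0, "interpolate_antialias": False, "interpolate_offset": 0.1}}
)
override = OmegaConf.create(
    {"student": {"num_register_tokens": 4, "interpolate_antialias": True, "interpolate_offset": 0.0}}
)
cfg = OmegaConf.merge(defaults, override)

print(cfg.student.num_register_tokens)  # 4
print(cfg.student.interpolate_offset)   # 0.0
```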
@@ -23,6 +23,9 @@ def _make_dinov2_model(
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
@@ -35,21 +38,25 @@ def _make_dinov2_model(
    except KeyError:
        raise AssertionError(f"Unsupported weights: {weights}")

    model_name = _make_dinov2_model_name(arch_name, patch_size)
    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
        model.load_state_dict(state_dict, strict=True)

    return model
@@ -80,5 +87,70 @@ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Wei
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
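The new `*_reg` entry points forward everything to `_make_dinov2_model`, so the usual keyword arguments still apply. For example, `pretrained=False` skips the checkpoint download:

```python
import torch

# Randomly initialized ViT-L/14 with 4 registers; no weights are downloaded.
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg', pretrained=False)

print(backbone.num_register_tokens)  # 4
print(backbone.embed_dim)            # 1024
```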
@@ -19,11 +19,13 @@ class Weights(Enum):

def _make_dinov2_linear_classification_head(
    *,
    model_name: str = "dinov2_vitl14",
    arch_name: str = "vit_large",
    patch_size: int = 14,
    embed_dim: int = 1024,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    num_register_tokens: int = 0,
    **kwargs,
):
    if layers not in (1, 4):
@@ -37,10 +39,12 @@ def _make_dinov2_linear_classification_head(
    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)

    if pretrained:
        model_base_name = _make_dinov2_model_name(arch_name, patch_size)
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        layers_str = str(layers) if layers == 4 else ""
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        linear_head.load_state_dict(state_dict, strict=False)
        linear_head.load_state_dict(state_dict, strict=True)

    return linear_head
|
@ -85,63 +89,180 @@ def _make_dinov2_linear_classifier(
|
|||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
||||
num_register_tokens: int = 0,
|
||||
interpolate_antialias: bool = False,
|
||||
interpolate_offset: float = 0.1,
|
||||
**kwargs,
|
||||
):
|
||||
backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
|
||||
backbone = _make_dinov2_model(
|
||||
arch_name=arch_name,
|
||||
pretrained=pretrained,
|
||||
num_register_tokens=num_register_tokens,
|
||||
interpolate_antialias=interpolate_antialias,
|
||||
interpolate_offset=interpolate_offset,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embed_dim = backbone.embed_dim
|
||||
patch_size = backbone.patch_size
|
||||
model_name = _make_dinov2_model_name(arch_name, patch_size)
|
||||
linear_head = _make_dinov2_linear_classification_head(
|
||||
model_name=model_name,
|
||||
arch_name=arch_name,
|
||||
patch_size=patch_size,
|
||||
embed_dim=embed_dim,
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=num_register_tokens,
|
||||
)
|
||||
|
||||
return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
|
||||
|
||||
|
||||
def dinov2_vits14_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
*,
|
||||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
|
||||
arch_name="vit_small",
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitb14_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
*,
|
||||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
|
||||
arch_name="vit_base",
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitl14_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
*,
|
||||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
|
||||
arch_name="vit_large",
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitg14_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
*,
|
||||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
weights: Union[Weights, str] = Weights.IMAGENET1K,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
|
||||
arch_name="vit_giant2",
|
||||
layers=layers,
|
||||
ffn_layer="swiglufused",
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vits14_reg_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_small",
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitb14_reg_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_base",
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitl14_reg_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_large",
|
||||
layers=layers,
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitg14_reg_lc(
|
||||
*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
|
||||
):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name="vit_giant2",
|
||||
layers=layers,
|
||||
ffn_layer="swiglufused",
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
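As with the existing classifiers, the `layers` argument selects which of the released linear heads ("1 layer" or "4 layers" in the README table above) is attached, for the register variants too:

```python
import torch

# Single-layer head instead of the default 4-layer variant.
classifier_1 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc', layers=1)

# Default: the 4-layer head listed in the table above.
classifier_4 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc', layers=4)
```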
@@ -14,9 +14,10 @@ import torch.nn.functional as F
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    return f"dinov2_{compact_arch_name}{patch_size}"
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
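A quick check of the names this helper produces, which is also how the `*_reg4_*` checkpoint filenames above are formed (assuming the helper lives in `dinov2.hub.utils`, as in the upstream layout):

```python
from dinov2.hub.utils import _make_dinov2_model_name

print(_make_dinov2_model_name("vit_small", 14))      # dinov2_vits14
print(_make_dinov2_model_name("vit_small", 14, 4))   # dinov2_vits14_reg4
print(_make_dinov2_model_name("vit_giant2", 14, 4))  # dinov2_vitg14_reg4
```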
@@ -23,6 +23,9 @@ def build_model(args, only_teacher=False, img_size=224):
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
@@ -62,6 +62,9 @@ class DinoVisionTransformer(nn.Module):
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
@@ -84,6 +87,9 @@ class DinoVisionTransformer(nn.Module):
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
|||
self.n_blocks = depth
|
||||
self.num_heads = num_heads
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = num_register_tokens
|
||||
self.interpolate_antialias = interpolate_antialias
|
||||
self.interpolate_offset = interpolate_offset
|
||||
|
||||
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
||||
assert num_register_tokens >= 0
|
||||
self.register_tokens = (
|
||||
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
||||
)
|
||||
|
||||
if drop_path_uniform is True:
|
||||
dpr = [drop_path_rate] * depth
|
||||
|
@@ -159,6 +172,8 @@ class DinoVisionTransformer(nn.Module):
    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
@@ -175,7 +190,7 @@ class DinoVisionTransformer(nn.Module):
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
@@ -183,6 +198,7 @@ class DinoVisionTransformer(nn.Module):
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
@@ -199,6 +215,16 @@ class DinoVisionTransformer(nn.Module):
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
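After this change the token order is [class token, 4 register tokens, patch tokens]. A small shape check on a randomly initialized model (the method modified above is `prepare_tokens_with_masks` in the upstream source; the import path assumes a local clone):

```python
import torch
from dinov2.models.vision_transformer import vit_small

# ViT-S/14 with 4 registers; 224/14 = 16, so 256 patch tokens.
model = vit_small(patch_size=14, num_register_tokens=4)
x = torch.zeros(2, 3, 224, 224)

tokens = model.prepare_tokens_with_masks(x)
print(tokens.shape)  # torch.Size([2, 261, 384]) -> 1 class + 4 register + 256 patch tokens
```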
@@ -213,7 +239,8 @@ class DinoVisionTransformer(nn.Module):
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_patchtokens": x_norm[:, 1:],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
@@ -232,7 +259,8 @@ class DinoVisionTransformer(nn.Module):
        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_patchtokens": x_norm[:, 1:],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }
@@ -305,7 +333,7 @@ def init_weights_vit_timm(module: nn.Module, name: str = ""):
        nn.init.zeros_(module.bias)


def vit_small(patch_size=16, **kwargs):
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
@@ -313,12 +341,13 @@ def vit_small(patch_size=16, **kwargs):
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, **kwargs):
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
@@ -326,12 +355,13 @@ def vit_base(patch_size=16, **kwargs):
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, **kwargs):
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
@@ -339,12 +369,13 @@ def vit_large(patch_size=16, **kwargs):
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, **kwargs):
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
@@ -355,6 +386,7 @@ def vit_giant2(patch_size=16, **kwargs):
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
@@ -5,7 +5,9 @@

from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
from dinov2.hub.backbones import dinov2_vitb14_reg, dinov2_vitg14_reg, dinov2_vitl14_reg, dinov2_vits14_reg
from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
from dinov2.hub.classifiers import dinov2_vitb14_reg_lc, dinov2_vitg14_reg_lc, dinov2_vitl14_reg_lc, dinov2_vits14_reg_lc
from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld
from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd