[Feature] Support DaViT. (#1105)
* add davit * fix mixup config * convert scripts * lint * test * test * Add checkpoint links. Co-authored-by: mzr1996 <mzr1996@163.com>pull/1207/head
parent
992d13e772
commit
c4f3883a22
|
@ -150,6 +150,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
|
|||
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
|
||||
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
||||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -149,6 +149,7 @@ mim install -e .
|
|||
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
|
||||
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
|
||||
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
|
||||
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=236,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackClsInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root='data/imagenet',
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,16 @@
|
|||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='DaViT', arch='base', out_indices=(3, ), drop_path_rate=0.4),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=1024,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,16 @@
|
|||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='DaViT', arch='small', out_indices=(3, ), drop_path_rate=0.2),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,16 @@
|
|||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='DaViT', arch='t', out_indices=(3, ), drop_path_rate=0.1),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]))
|
|
@ -0,0 +1,38 @@
|
|||
# DaViT
|
||||
|
||||
> [DaViT: Dual Attention Vision Transformers](https://arxiv.org/abs/2204.03645v1)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
In this work, we introduce Dual Attention Vision Transformers (DaViT), a simple yet effective vision transformer architecture that is able to capture global context while maintaining computational efficiency. We propose approaching the problem from an orthogonal angle: exploiting self-attention mechanisms with both "spatial tokens" and "channel tokens". With spatial tokens, the spatial dimension defines the token scope, and the channel dimension defines the token feature dimension. With channel tokens, we have the inverse: the channel dimension defines the token scope, and the spatial dimension defines the token feature dimension. We further group tokens along the sequence direction for both spatial and channel tokens to maintain the linear complexity of the entire model. We show that these two self-attentions complement each other: (i) since each channel token contains an abstract representation of the entire image, the channel attention naturally captures global interactions and representations by taking all spatial positions into account when computing attention scores between channels; (ii) the spatial attention refines the local representations by performing fine-grained interactions across spatial locations, which in turn helps the global information modeling in channel attention. Extensive experiments show our DaViT achieves state-of-the-art performance on four different tasks with efficient computations. Without extra data, DaViT-Tiny, DaViT-Small, and DaViT-Base achieve 82.8%, 84.2%, and 84.6% top-1 accuracy on ImageNet-1K with 28.3M, 49.7M, and 87.9M parameters, respectively. When we further scale up DaViT with 1.5B weakly supervised image and text pairs, DaViT-Gaint reaches 90.4% top-1 accuracy on ImageNet-1K.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/24734142/196125065-e232409b-f710-4729-b657-4e5f9158f2d1.png" width="90%"/>
|
||||
</div>
|
||||
|
||||
## Results and models
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
| :-------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :------------------------------------: | :----------------------------------------------------------------------------------------------: |
|
||||
| DaViT-T\* | From scratch | 224x224 | 28.36 | 4.54 | 82.24 | 96.13 | [config](./davit-tiny_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-tiny_3rdparty_in1k_20221116-700fdf7d.pth) |
|
||||
| DaViT-S\* | From scratch | 224x224 | 49.74 | 8.79 | 83.61 | 96.75 | [config](./davit-small_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-small_3rdparty_in1k_20221116-51a849a6.pth) |
|
||||
| DaViT-B\* | From scratch | 224x224 | 87.95 | 15.5 | 84.09 | 96.82 | [config](./davit-base_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-base_3rdparty_in1k_20221116-19e0d956.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/dingmyu/davit). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
|
||||
|
||||
Note: Inference accuracy is a bit lower than paper result because of inference code for classification doesn't exist.
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
@inproceedings{ding2022davit,
|
||||
title={DaViT: Dual Attention Vision Transformer},
|
||||
author={Ding, Mingyu and Xiao, Bin and Codella, Noel and Luo, Ping and Wang, Jingdong and Yuan, Lu},
|
||||
booktitle={ECCV},
|
||||
year={2022},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/davit/davit-base.py',
|
||||
'../_base_/datasets/imagenet_bs256_davit_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# data settings
|
||||
train_dataloader = dict(batch_size=256)
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/davit/davit-small.py',
|
||||
'../_base_/datasets/imagenet_bs256_davit_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# data settings
|
||||
train_dataloader = dict(batch_size=256)
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/davit/davit-tiny.py',
|
||||
'../_base_/datasets/imagenet_bs256_davit_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# data settings
|
||||
train_dataloader = dict(batch_size=256)
|
|
@ -0,0 +1,71 @@
|
|||
Collections:
|
||||
- Name: DaViT
|
||||
Metadata:
|
||||
Architecture:
|
||||
- GELU
|
||||
- Layer Normalization
|
||||
- Multi-Head Attention
|
||||
- Scaled Dot-Product Attention
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2204.03645v1
|
||||
Title: 'DaViT: Dual Attention Vision Transformers'
|
||||
README: configs/davit/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/backbones/davit.py
|
||||
Version: v1.0.0rc3
|
||||
|
||||
Models:
|
||||
- Name: davit-tiny_3rdparty_in1k
|
||||
In Collection: DaViT
|
||||
Metadata:
|
||||
FLOPs: 4539698688
|
||||
Parameters: 28360168
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Task: Image Classification
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.24
|
||||
Top 5 Accuracy: 96.13
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-tiny_3rdparty_in1k_20221116-700fdf7d.pth
|
||||
Converted From:
|
||||
Weights: https://drive.google.com/file/d/1RSpi3lxKaloOL5-or20HuG975tbPwxRZ/view?usp=sharing
|
||||
Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355
|
||||
Config: configs/davit/davit-tiny_4xb256_in1k.py
|
||||
- Name: davit-small_3rdparty_in1k
|
||||
In Collection: DaViT
|
||||
Metadata:
|
||||
FLOPs: 8799942144
|
||||
Parameters: 49745896
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Task: Image Classification
|
||||
Metrics:
|
||||
Top 1 Accuracy: 83.61
|
||||
Top 5 Accuracy: 96.75
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-small_3rdparty_in1k_20221116-51a849a6.pth
|
||||
Converted From:
|
||||
Weights: https://drive.google.com/file/d/1q976ruj45mt0RhO9oxhOo6EP_cmj4ahQ/view?usp=sharing
|
||||
Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355
|
||||
Config: configs/davit/davit-small_4xb256_in1k.py
|
||||
- Name: davit-base_3rdparty_in1k
|
||||
In Collection: DaViT
|
||||
Metadata:
|
||||
FLOPs: 15509702656
|
||||
Parameters: 87954408
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Task: Image Classification
|
||||
Metrics:
|
||||
Top 1 Accuracy: 84.09
|
||||
Top 5 Accuracy: 96.82
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-base_3rdparty_in1k_20221116-19e0d956.pth
|
||||
Converted From:
|
||||
Weights: https://drive.google.com/file/d/1u9sDBEueB-YFuLigvcwf4b2YyA4MIVsZ/view?usp=sharing
|
||||
Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355
|
||||
Config: configs/davit/davit-base_4xb256_in1k.py
|
|
@ -65,6 +65,7 @@ Backbones
|
|||
Conformer
|
||||
ConvMixer
|
||||
ConvNeXt
|
||||
DaViT
|
||||
DeiT3
|
||||
DenseNet
|
||||
DistilledVisionTransformer
|
||||
|
|
|
@ -4,6 +4,7 @@ from .conformer import Conformer
|
|||
from .convmixer import ConvMixer
|
||||
from .convnext import ConvNeXt
|
||||
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
|
||||
from .davit import DaViT
|
||||
from .deit import DistilledVisionTransformer
|
||||
from .deit3 import DeiT3
|
||||
from .densenet import DenseNet
|
||||
|
@ -93,4 +94,5 @@ __all__ = [
|
|||
'DeiT3',
|
||||
'HorNet',
|
||||
'MobileViT',
|
||||
'DaViT',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,834 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmcv.cnn.bricks import Conv2d
|
||||
from mmcv.cnn.bricks.transformer import FFN, AdaptivePadding, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.utils import to_2tuple
|
||||
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
from mmcls.registry import MODELS
|
||||
from ..utils import ShiftWindowMSA
|
||||
|
||||
|
||||
class DaViTWindowMSA(BaseModule):
|
||||
"""Window based multi-head self-attention (W-MSA) module for DaViT.
|
||||
|
||||
The differences between DaViTWindowMSA & WindowMSA:
|
||||
1. Without relative position bias.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
|
||||
Defaults to True.
|
||||
qk_scale (float, optional): Override default qk scale of
|
||||
``head_dim ** -0.5`` if set. Defaults to None.
|
||||
attn_drop (float, optional): Dropout ratio of attention weight.
|
||||
Defaults to 0.
|
||||
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
init_cfg=None):
|
||||
|
||||
super().__init__(init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.scale = qk_scale or head_embed_dims**-0.5
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
|
||||
x (tensor): input features with shape of (num_windows*B, N, C)
|
||||
mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
|
||||
Wh*Ww), value should be between (-inf, 0].
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def double_step_seq(step1, len1, step2, len2):
|
||||
seq1 = torch.arange(0, step1 * len1, step1)
|
||||
seq2 = torch.arange(0, step2 * len2, step2)
|
||||
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
||||
|
||||
|
||||
class ConvPosEnc(BaseModule):
|
||||
"""DaViT conv pos encode block.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
kernel_size (int): The kernel size of the first convolution.
|
||||
Defaults to 3.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dims, kernel_size=3, init_cfg=None):
|
||||
super(ConvPosEnc, self).__init__(init_cfg)
|
||||
self.proj = Conv2d(
|
||||
embed_dims,
|
||||
embed_dims,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=kernel_size // 2,
|
||||
groups=embed_dims)
|
||||
|
||||
def forward(self, x, size: Tuple[int, int]):
|
||||
B, N, C = x.shape
|
||||
H, W = size
|
||||
assert N == H * W
|
||||
|
||||
feat = x.transpose(1, 2).view(B, C, H, W)
|
||||
feat = self.proj(feat)
|
||||
feat = feat.flatten(2).transpose(1, 2)
|
||||
x = x + feat
|
||||
return x
|
||||
|
||||
|
||||
class DaViTDownSample(BaseModule):
|
||||
"""DaViT down sampole block.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channels.
|
||||
out_channels (int): The number of output channels.
|
||||
conv_type (str): The type of convolution
|
||||
to generate patch embedding. Default: "Conv2d".
|
||||
kernel_size (int): The kernel size of the first convolution.
|
||||
Defaults to 2.
|
||||
stride (int): The stride of the second convluation module.
|
||||
Defaults to 2.
|
||||
padding (int | tuple | string ): The padding length of
|
||||
embedding conv. When it is a string, it means the mode
|
||||
of adaptive padding, support "same" and "corner" now.
|
||||
Defaults to "corner".
|
||||
dilation (int): Dilation of the convolution layers. Defaults to 1.
|
||||
bias (bool): Bias of embed conv. Default: True.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
padding='same',
|
||||
dilation=1,
|
||||
bias=True,
|
||||
norm_cfg=None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.out_channels = out_channels
|
||||
if stride is None:
|
||||
stride = kernel_size
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
if isinstance(padding, str):
|
||||
self.adaptive_padding = AdaptivePadding(
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding)
|
||||
# disable the padding of conv
|
||||
padding = 0
|
||||
else:
|
||||
self.adaptive_padding = None
|
||||
padding = to_2tuple(padding)
|
||||
|
||||
self.projection = build_conv_layer(
|
||||
dict(type=conv_type),
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias)
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, in_channels)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x, input_size):
|
||||
if self.adaptive_padding:
|
||||
x = self.adaptive_padding(x)
|
||||
H, W = input_size
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
||||
x = self.norm(x)
|
||||
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
x = self.projection(x)
|
||||
output_size = (x.size(2), x.size(3))
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x, output_size
|
||||
|
||||
|
||||
class ChannelAttention(BaseModule):
|
||||
"""DaViT channel attention.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dims, num_heads=8, qkv_bias=False, init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
self.head_dims = embed_dims // num_heads
|
||||
self.scale = self.head_dims**-0.5
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, _ = x.shape
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
self.head_dims).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
k = k * self.scale
|
||||
attention = k.transpose(-1, -2) @ v
|
||||
attention = attention.softmax(dim=-1)
|
||||
|
||||
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
|
||||
x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ChannelBlock(BaseModule):
|
||||
"""DaViT channel attention block.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
ffn_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop_path=0.,
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.attn = ChannelAttention(
|
||||
embed_dims, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
|
||||
|
||||
_ffn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'feedforward_channels': int(embed_dims * ffn_ratio),
|
||||
'num_fcs': 2,
|
||||
'ffn_drop': 0,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'act_cfg': dict(type='GELU'),
|
||||
**ffn_cfgs
|
||||
}
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = FFN(**_ffn_cfgs)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.cpe1(x, hw_shape)
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x)
|
||||
x = x + identity
|
||||
|
||||
x = self.cpe2(x, hw_shape)
|
||||
identity = x
|
||||
x = self.norm2(x)
|
||||
x = self.ffn(x, identity=identity)
|
||||
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SpatialBlock(BaseModule):
|
||||
"""DaViT spatial attention block.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
attn_cfgs (dict): The extra config of Shift Window-MSA.
|
||||
Defaults to empty dict.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
ffn_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path=0.,
|
||||
pad_small_map=False,
|
||||
attn_cfgs=dict(),
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
|
||||
super(SpatialBlock, self).__init__(init_cfg)
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.cpe1 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
|
||||
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
_attn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'num_heads': num_heads,
|
||||
'shift_size': 0,
|
||||
'window_size': window_size,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'qkv_bias': qkv_bias,
|
||||
'pad_small_map': pad_small_map,
|
||||
'window_msa': DaViTWindowMSA,
|
||||
**attn_cfgs
|
||||
}
|
||||
self.attn = ShiftWindowMSA(**_attn_cfgs)
|
||||
self.cpe2 = ConvPosEnc(embed_dims=embed_dims, kernel_size=3)
|
||||
|
||||
_ffn_cfgs = {
|
||||
'embed_dims': embed_dims,
|
||||
'feedforward_channels': int(embed_dims * ffn_ratio),
|
||||
'num_fcs': 2,
|
||||
'ffn_drop': 0,
|
||||
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
|
||||
'act_cfg': dict(type='GELU'),
|
||||
**ffn_cfgs
|
||||
}
|
||||
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = FFN(**_ffn_cfgs)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
|
||||
def _inner_forward(x):
|
||||
x = self.cpe1(x, hw_shape)
|
||||
identity = x
|
||||
x = self.norm1(x)
|
||||
x = self.attn(x, hw_shape)
|
||||
x = x + identity
|
||||
|
||||
x = self.cpe2(x, hw_shape)
|
||||
identity = x
|
||||
x = self.norm2(x)
|
||||
x = self.ffn(x, identity=identity)
|
||||
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DaViTBlock(BaseModule):
|
||||
"""DaViT block.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
drop_path (float): The drop path rate after attention and ffn.
|
||||
Defaults to 0.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
attn_cfgs (dict): The extra config of Shift Window-MSA.
|
||||
Defaults to empty dict.
|
||||
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
|
||||
norm_cfg (dict): The config of norm layers.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
ffn_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path=0.,
|
||||
pad_small_map=False,
|
||||
attn_cfgs=dict(),
|
||||
ffn_cfgs=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
|
||||
super(DaViTBlock, self).__init__(init_cfg)
|
||||
self.spatial_block = SpatialBlock(
|
||||
embed_dims,
|
||||
num_heads,
|
||||
window_size=window_size,
|
||||
ffn_ratio=ffn_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop_path=drop_path,
|
||||
pad_small_map=pad_small_map,
|
||||
attn_cfgs=attn_cfgs,
|
||||
ffn_cfgs=ffn_cfgs,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=with_cp)
|
||||
self.channel_block = ChannelBlock(
|
||||
embed_dims,
|
||||
num_heads,
|
||||
ffn_ratio=ffn_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop_path=drop_path,
|
||||
ffn_cfgs=ffn_cfgs,
|
||||
norm_cfg=norm_cfg,
|
||||
with_cp=False)
|
||||
|
||||
def forward(self, x, hw_shape):
|
||||
x = self.spatial_block(x, hw_shape)
|
||||
x = self.channel_block(x, hw_shape)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DaViTBlockSequence(BaseModule):
|
||||
"""Module with successive DaViT blocks and downsample layer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
depth (int): Number of successive DaViT blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
downsample (bool): Downsample the output of blocks by patch merging.
|
||||
Defaults to False.
|
||||
downsample_cfg (dict): The extra config of the patch merging layer.
|
||||
Defaults to empty dict.
|
||||
drop_paths (Sequence[float] | float): The drop path rate in each block.
|
||||
Defaults to 0.
|
||||
block_cfgs (Sequence[dict] | dict): The extra config of each block.
|
||||
Defaults to empty dicts.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
ffn_ratio=4.,
|
||||
qkv_bias=True,
|
||||
downsample=False,
|
||||
downsample_cfg=dict(),
|
||||
drop_paths=0.,
|
||||
block_cfgs=dict(),
|
||||
with_cp=False,
|
||||
pad_small_map=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if not isinstance(drop_paths, Sequence):
|
||||
drop_paths = [drop_paths] * depth
|
||||
|
||||
if not isinstance(block_cfgs, Sequence):
|
||||
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.blocks = ModuleList()
|
||||
for i in range(depth):
|
||||
_block_cfg = {
|
||||
'embed_dims': embed_dims,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'ffn_ratio': ffn_ratio,
|
||||
'qkv_bias': qkv_bias,
|
||||
'drop_path': drop_paths[i],
|
||||
'with_cp': with_cp,
|
||||
'pad_small_map': pad_small_map,
|
||||
**block_cfgs[i]
|
||||
}
|
||||
block = DaViTBlock(**_block_cfg)
|
||||
self.blocks.append(block)
|
||||
|
||||
if downsample:
|
||||
_downsample_cfg = {
|
||||
'in_channels': embed_dims,
|
||||
'out_channels': 2 * embed_dims,
|
||||
'norm_cfg': dict(type='LN'),
|
||||
**downsample_cfg
|
||||
}
|
||||
self.downsample = DaViTDownSample(**_downsample_cfg)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, in_shape, do_downsample=True):
|
||||
for block in self.blocks:
|
||||
x = block(x, in_shape)
|
||||
|
||||
if self.downsample is not None and do_downsample:
|
||||
x, out_shape = self.downsample(x, in_shape)
|
||||
else:
|
||||
out_shape = in_shape
|
||||
return x, out_shape
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
if self.downsample:
|
||||
return self.downsample.out_channels
|
||||
else:
|
||||
return self.embed_dims
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DaViT(BaseBackbone):
|
||||
"""DaViT.
|
||||
|
||||
A PyTorch implement of : `DaViT: Dual Attention Vision Transformers
|
||||
<https://arxiv.org/abs/2204.03645v1>`_
|
||||
|
||||
Inspiration from
|
||||
https://github.com/dingmyu/davit
|
||||
|
||||
Args:
|
||||
arch (str | dict): DaViT architecture. If use string, choose from
|
||||
'tiny', 'small', 'base' and 'large', 'huge', 'giant'. If use dict,
|
||||
it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **num_heads** (List[int]): The number of heads in attention
|
||||
modules of each stage.
|
||||
|
||||
Defaults to 't'.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 4.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
window_size (int): The height and width of the window. Defaults to 7.
|
||||
ffn_ratio (float): The expansion ratio of feedforward network hidden
|
||||
layer channels. Defaults to 4.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
|
||||
out_after_downsample (bool): Whether to output the feature map of a
|
||||
stage after the following downsample layer. Defaults to False.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is common used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is common used in classification.
|
||||
Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer for all output
|
||||
features. Defaults to ``dict(type='LN')``
|
||||
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
|
||||
stage. Defaults to an empty dict.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Defaults to False.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(['t', 'tiny'], {
|
||||
'embed_dims': 96,
|
||||
'depths': [1, 1, 3, 1],
|
||||
'num_heads': [3, 6, 12, 24]
|
||||
}),
|
||||
**dict.fromkeys(['s', 'small'], {
|
||||
'embed_dims': 96,
|
||||
'depths': [1, 1, 9, 1],
|
||||
'num_heads': [3, 6, 12, 24]
|
||||
}),
|
||||
**dict.fromkeys(['b', 'base'], {
|
||||
'embed_dims': 128,
|
||||
'depths': [1, 1, 9, 1],
|
||||
'num_heads': [4, 8, 16, 32]
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['l', 'large'], {
|
||||
'embed_dims': 192,
|
||||
'depths': [1, 1, 9, 1],
|
||||
'num_heads': [6, 12, 24, 48]
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['h', 'huge'], {
|
||||
'embed_dims': 256,
|
||||
'depths': [1, 1, 9, 1],
|
||||
'num_heads': [8, 16, 32, 64]
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['g', 'giant'], {
|
||||
'embed_dims': 384,
|
||||
'depths': [1, 1, 12, 3],
|
||||
'num_heads': [12, 24, 48, 96]
|
||||
}),
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch='t',
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
window_size=7,
|
||||
ffn_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=0.1,
|
||||
out_after_downsample=False,
|
||||
pad_small_map=False,
|
||||
norm_cfg=dict(type='LN'),
|
||||
stage_cfgs=dict(),
|
||||
frozen_stages=-1,
|
||||
norm_eval=False,
|
||||
out_indices=(3, ),
|
||||
with_cp=False,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {'embed_dims', 'depths', 'num_heads'}
|
||||
assert isinstance(arch, dict) and essential_keys <= set(arch), \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.embed_dims = self.arch_settings['embed_dims']
|
||||
self.depths = self.arch_settings['depths']
|
||||
self.num_heads = self.arch_settings['num_heads']
|
||||
self.num_layers = len(self.depths)
|
||||
self.out_indices = out_indices
|
||||
self.out_after_downsample = out_after_downsample
|
||||
self.frozen_stages = frozen_stages
|
||||
self.norm_eval = norm_eval
|
||||
|
||||
# stochastic depth decay rule
|
||||
total_depth = sum(self.depths)
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=7,
|
||||
stride=patch_size,
|
||||
padding='same',
|
||||
norm_cfg=dict(type='LN'),
|
||||
)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
|
||||
self.stages = ModuleList()
|
||||
embed_dims = [self.embed_dims]
|
||||
for i, (depth,
|
||||
num_heads) in enumerate(zip(self.depths, self.num_heads)):
|
||||
if isinstance(stage_cfgs, Sequence):
|
||||
stage_cfg = stage_cfgs[i]
|
||||
else:
|
||||
stage_cfg = deepcopy(stage_cfgs)
|
||||
downsample = True if i < self.num_layers - 1 else False
|
||||
_stage_cfg = {
|
||||
'embed_dims': embed_dims[-1],
|
||||
'depth': depth,
|
||||
'num_heads': num_heads,
|
||||
'window_size': window_size,
|
||||
'ffn_ratio': ffn_ratio,
|
||||
'qkv_bias': qkv_bias,
|
||||
'downsample': downsample,
|
||||
'drop_paths': dpr[:depth],
|
||||
'with_cp': with_cp,
|
||||
'pad_small_map': pad_small_map,
|
||||
**stage_cfg
|
||||
}
|
||||
|
||||
stage = DaViTBlockSequence(**_stage_cfg)
|
||||
self.stages.append(stage)
|
||||
|
||||
dpr = dpr[depth:]
|
||||
embed_dims.append(stage.out_channels)
|
||||
|
||||
self.num_features = embed_dims[:-1]
|
||||
|
||||
# add a norm layer for each output
|
||||
for i in out_indices:
|
||||
if norm_cfg is not None:
|
||||
norm_layer = build_norm_layer(norm_cfg,
|
||||
self.num_features[i])[1]
|
||||
else:
|
||||
norm_layer = nn.Identity()
|
||||
|
||||
self.add_module(f'norm{i}', norm_layer)
|
||||
|
||||
def train(self, mode=True):
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(0, self.frozen_stages + 1):
|
||||
m = self.stages[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
for i in self.out_indices:
|
||||
if i <= self.frozen_stages:
|
||||
for param in getattr(self, f'norm{i}').parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(
|
||||
x, hw_shape, do_downsample=self.out_after_downsample)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape,
|
||||
self.num_features[i]).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
if stage.downsample is not None and not self.out_after_downsample:
|
||||
x, hw_shape = stage.downsample(x, hw_shape)
|
||||
|
||||
return tuple(outs)
|
|
@ -37,3 +37,4 @@ Import:
|
|||
- configs/deit3/metafile.yml
|
||||
- configs/hornet/metafile.yml
|
||||
- configs/mobilevit/metafile.yml
|
||||
- configs/davit/metafile.yml
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmcls.models.backbones import DaViT
|
||||
from mmcls.models.backbones.davit import SpatialBlock
|
||||
|
||||
|
||||
class TestDaViT(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.cfg = dict(arch='t', patch_size=4, drop_path_rate=0.1)
|
||||
|
||||
def test_structure(self):
|
||||
# Test invalid default arch
|
||||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
DaViT(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
DaViT(**cfg)
|
||||
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = {
|
||||
'embed_dims': 64,
|
||||
'num_heads': [3, 3, 3, 3],
|
||||
'depths': [1, 1, 2, 1]
|
||||
}
|
||||
model = DaViT(**cfg)
|
||||
self.assertEqual(model.embed_dims, 64)
|
||||
self.assertEqual(model.num_layers, 4)
|
||||
for layer in model.stages:
|
||||
self.assertEqual(
|
||||
layer.blocks[0].spatial_block.attn.w_msa.num_heads, 3)
|
||||
|
||||
def test_init_weights(self):
|
||||
# test weight init cfg
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['init_cfg'] = [
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer='Conv2d',
|
||||
mode='fan_in',
|
||||
nonlinearity='linear')
|
||||
]
|
||||
model = DaViT(**cfg)
|
||||
ori_weight = model.patch_embed.projection.weight.clone().detach()
|
||||
|
||||
model.init_weights()
|
||||
initialized_weight = model.patch_embed.projection.weight
|
||||
self.assertFalse(torch.allclose(ori_weight, initialized_weight))
|
||||
|
||||
def test_forward(self):
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = DaViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
self.assertEqual(outs[0].shape, (1, 768, 7, 7))
|
||||
|
||||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [2, 3]
|
||||
model = DaViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 2)
|
||||
self.assertEqual(outs[0].shape, (1, 384, 14, 14))
|
||||
self.assertEqual(outs[1].shape, (1, 768, 7, 7))
|
||||
|
||||
# test with checkpoint forward
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['with_cp'] = True
|
||||
model = DaViT(**cfg)
|
||||
for m in model.modules():
|
||||
if isinstance(m, SpatialBlock):
|
||||
self.assertTrue(m.with_cp)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
self.assertEqual(outs[0].shape, (1, 768, 7, 7))
|
||||
|
||||
# Test forward with dynamic input size
|
||||
imgs1 = torch.randn(1, 3, 224, 224)
|
||||
imgs2 = torch.randn(1, 3, 256, 256)
|
||||
imgs3 = torch.randn(1, 3, 256, 309)
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = DaViT(**cfg)
|
||||
for imgs in [imgs1, imgs2, imgs3]:
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
expect_feat_shape = (imgs.shape[2] // 32, imgs.shape[3] // 32)
|
||||
self.assertEqual(outs[0].shape, (1, 768, *expect_feat_shape))
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
|
||||
import mmengine
|
||||
import torch
|
||||
from mmengine.runner import CheckpointLoader
|
||||
|
||||
|
||||
def convert_davit(ckpt):
|
||||
|
||||
new_ckpt = OrderedDict()
|
||||
|
||||
for k, v in list(ckpt.items()):
|
||||
new_v = v
|
||||
if k.startswith('patch_embeds.0'):
|
||||
new_k = k.replace('patch_embeds.0', 'patch_embed')
|
||||
new_k = new_k.replace('proj', 'projection')
|
||||
elif k.startswith('patch_embeds'):
|
||||
if k.startswith('patch_embeds.1'):
|
||||
new_k = k.replace('patch_embeds.1', 'stages.0.downsample')
|
||||
elif k.startswith('patch_embeds.2'):
|
||||
new_k = k.replace('patch_embeds.2', 'stages.1.downsample')
|
||||
elif k.startswith('patch_embeds.3'):
|
||||
new_k = k.replace('patch_embeds.3', 'stages.2.downsample')
|
||||
new_k = new_k.replace('proj', 'projection')
|
||||
elif k.startswith('main_blocks'):
|
||||
new_k = k.replace('main_blocks', 'stages')
|
||||
for num_stages in range(4):
|
||||
for num_blocks in range(9):
|
||||
if f'{num_stages}.{num_blocks}.0' in k:
|
||||
new_k = new_k.replace(
|
||||
f'{num_stages}.{num_blocks}.0',
|
||||
f'{num_stages}.blocks.{num_blocks}.spatial_block')
|
||||
elif f'{num_stages}.{num_blocks}.1' in k:
|
||||
new_k = new_k.replace(
|
||||
f'{num_stages}.{num_blocks}.1',
|
||||
f'{num_stages}.blocks.{num_blocks}.channel_block')
|
||||
if 'cpe.0' in k:
|
||||
new_k = new_k.replace('cpe.0', 'cpe1')
|
||||
elif 'cpe.1' in k:
|
||||
new_k = new_k.replace('cpe.1', 'cpe2')
|
||||
if 'mlp' in k:
|
||||
new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
|
||||
new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
|
||||
if 'spatial_block.attn' in new_k:
|
||||
new_k = new_k.replace('spatial_block.attn',
|
||||
'spatial_block.attn.w_msa')
|
||||
elif k.startswith('norms'):
|
||||
new_k = k.replace('norms', 'norm3')
|
||||
elif k.startswith('head'):
|
||||
new_k = k.replace('head', 'head.fc')
|
||||
else:
|
||||
new_k = k
|
||||
|
||||
if not new_k.startswith('head'):
|
||||
new_k = 'backbone.' + new_k
|
||||
new_ckpt[new_k] = new_v
|
||||
return new_ckpt
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert keys in pretrained van models to mmcls style.')
|
||||
parser.add_argument('src', help='src model path or url')
|
||||
# The dst path must be a full path of the new checkpoint.
|
||||
parser.add_argument('dst', help='save path')
|
||||
args = parser.parse_args()
|
||||
|
||||
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
|
||||
|
||||
if 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
weight = convert_davit(state_dict)
|
||||
mmengine.mkdir_or_exist(osp.dirname(args.dst))
|
||||
torch.save(weight, args.dst)
|
||||
|
||||
print('Done!!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue