mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
add BEIT3
This commit is contained in:
parent
c8c4f256b8
commit
7aeebf20e2
@ -1,4 +1,5 @@
|
||||
from .beit import *
|
||||
from .beit3 import *
|
||||
from .byoanet import *
|
||||
from .byobnet import *
|
||||
from .cait import *
|
||||
|
492
timm/models/beit3.py
Normal file
492
timm/models/beit3.py
Normal file
@ -0,0 +1,492 @@
|
||||
""" BEiT3
|
||||
Paper: `Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks`
|
||||
- https://arxiv.org/abs/2208.10442
|
||||
- https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_Image_as_a_Foreign_Language_BEiT_Pretraining_for_Vision_and_CVPR_2023_paper.pdf
|
||||
|
||||
Model from official source:
|
||||
- https://github.com/microsoft/unilm/tree/master/beit3
|
||||
- https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py
|
||||
|
||||
@inproceedings{beit3,
|
||||
title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks},
|
||||
author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
|
||||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
||||
year={2023}
|
||||
}
|
||||
@InProceedings{Wang_2023_CVPR,
|
||||
author = {Wang, Wenhui and Bao, Hangbo and Dong, Li and Bjorck, Johan and Peng, Zhiliang and Liu, Qiang and Aggarwal, Kriti and Mohammed, Owais Khan and Singhal, Saksham and Som, Subhojit and Wei, Furu},
|
||||
title = {Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks},
|
||||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
month = {June},
|
||||
year = {2023},
|
||||
pages = {19175-19186}
|
||||
}
|
||||
|
||||
Original implementation by Wenhui Wang et al.,
|
||||
adapted for timm by Ryan Hou and Ross Wightman.
|
||||
|
||||
At this point only the 1k fine-tuned classification weights and model configs have been added,
|
||||
see original source above for pre-training models and procedure.
|
||||
|
||||
Adapted from https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py, original copyright below
|
||||
"""
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
|
||||
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------'
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, LayerNorm, DropPath, trunc_normal_, LayerType
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['BEiT3']
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Embedding):
|
||||
"""
|
||||
Reference from:
|
||||
https://github.com/microsoft/torchscale/blob/main/torchscale/component/embedding.py#L99-L119
|
||||
"""
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.embedding(
|
||||
torch.arange(2, self.num_embeddings).long().unsqueeze(0).to(x.device),
|
||||
self.weight,
|
||||
self.padding_idx,
|
||||
self.max_norm,
|
||||
self.norm_type,
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Reference from:
|
||||
https://github.com/microsoft/torchscale/blob/main/torchscale/component/multihead_attention.py#L20-L171
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
drop_rate: float = 0.,
|
||||
norm_layer: LayerType = partial(LayerNorm, eps=1e-5)
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.k_proj = nn.Linear(dim, dim)
|
||||
self.v_proj = nn.Linear(dim, dim)
|
||||
self.q_proj = nn.Linear(dim, dim)
|
||||
self.out_proj = nn.Linear(dim, dim)
|
||||
self.inner_attn_ln = norm_layer(dim)
|
||||
self.attn_drop = nn.Dropout(drop_rate)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
q *= self.scaling
|
||||
|
||||
q = q.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q = q.reshape(B * self.num_heads, N, self.head_dim)
|
||||
k = k.reshape(B * self.num_heads, N, self.head_dim)
|
||||
v = v.reshape(B * self.num_heads, N, self.head_dim)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
|
||||
attn_weights
|
||||
)
|
||||
attn_probs = self.attn_drop(attn_weights)
|
||||
|
||||
attn = torch.bmm(attn_probs, v)
|
||||
attn = attn.transpose(0, 1).reshape(N, B, C).transpose(0, 1)
|
||||
attn = self.inner_attn_ln(attn)
|
||||
attn = self.out_proj(attn)
|
||||
return attn
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.,
|
||||
drop_rate: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
act_layer: LayerType = nn.GELU,
|
||||
norm_layer: LayerType = partial(LayerNorm, eps=1e-5),
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(dim, num_heads=num_heads, drop_rate=attn_drop, norm_layer=norm_layer)
|
||||
self.attn_drop = nn.Dropout(drop_rate)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.drop_path(self.attn_drop(self.attn(self.norm1(x))))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class BEiT3(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
act_layer: LayerType = nn.GELU,
|
||||
norm_layer: LayerType = partial(LayerNorm, eps=1e-5),
|
||||
head_init_scale: float = 0.001,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
|
||||
self.num_prefix_tokens = 1
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# vision_embed
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
|
||||
# encoder
|
||||
self.pos_embed = PositionalEmbedding(num_patches + 3, embed_dim)
|
||||
self.pos_drop = nn.Dropout(drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop_rate=drop_rate,
|
||||
drop_path=dpr[i],
|
||||
attn_drop=attn_drop_rate,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for i in range(depth)])
|
||||
self.feature_info = [
|
||||
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
|
||||
|
||||
# class_head
|
||||
use_fc_norm = self.global_pool == 'avg'
|
||||
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
self.fix_init_weight(depth)
|
||||
if isinstance(self.head, nn.Linear):
|
||||
trunc_normal_(self.head.weight, std=.02)
|
||||
self.head.weight.data.mul_(head_init_scale)
|
||||
self.head.bias.data.mul_(head_init_scale)
|
||||
|
||||
def fix_init_weight(self, depth: int):
|
||||
init_scale = math.sqrt(math.log(depth * 2))
|
||||
for name, p in self.named_parameters():
|
||||
if (
|
||||
"fc1" in name
|
||||
or "fc2" in name
|
||||
or "out_proj" in name
|
||||
or "v_proj" in name
|
||||
):
|
||||
p.data.mul_(init_scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self) -> Set:
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse: bool = False) -> Dict:
|
||||
matcher = dict(
|
||||
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
|
||||
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
|
||||
)
|
||||
return matcher
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self) -> nn.Module:
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_intermediates(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
indices: Optional[Union[int, List[int]]] = None,
|
||||
return_prefix_tokens: bool = False,
|
||||
norm: bool = False,
|
||||
stop_early: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
intermediates_only: bool = False,
|
||||
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
""" Forward features that returns intermediates.
|
||||
|
||||
Args:
|
||||
x: Input image tensor
|
||||
indices: Take last n blocks if an int, if is a sequence, select by matching indices
|
||||
return_prefix_tokens: Return both prefix and spatial intermediate tokens
|
||||
norm: Apply norm layer to all intermediates
|
||||
stop_early: Stop iterating over blocks when last desired intermediate hit
|
||||
output_fmt: Shape of intermediate feature outputs
|
||||
intermediates_only: Only return intermediate features
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
|
||||
reshape = output_fmt == 'NCHW'
|
||||
intermediates = []
|
||||
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||
|
||||
# forward pass
|
||||
B, _, height, width = x.shape
|
||||
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.pos_embed(x)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
blocks = self.blocks
|
||||
else:
|
||||
blocks = self.blocks[:max_index + 1]
|
||||
|
||||
for i, blk in enumerate(blocks):
|
||||
x = blk(x)
|
||||
if i in take_indices:
|
||||
# normalize intermediates with final norm layer if enabled
|
||||
intermediates.append(self.norm(x) if norm else x)
|
||||
|
||||
# process intermediates
|
||||
if self.num_prefix_tokens:
|
||||
# split prefix (e.g. class, distill) and spatial feature tokens
|
||||
prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
|
||||
intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
|
||||
|
||||
if reshape:
|
||||
# reshape to BCHW output format
|
||||
H, W = self.patch_embed.dynamic_feat_size((height, width))
|
||||
intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
|
||||
if not torch.jit.is_scripting() and return_prefix_tokens:
|
||||
# return_prefix not support in torchscript due to poor type handling
|
||||
intermediates = list(zip(intermediates, prefix_tokens))
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
def prune_intermediate_layers(
|
||||
self,
|
||||
indices: Union[int, List[int]] = 1,
|
||||
prune_norm: bool = False,
|
||||
prune_head: bool = True,
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
|
||||
self.blocks = self.blocks[:max_index + 1] # truncate blocks
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
self.fc_norm = nn.Identity()
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.pos_embed(x)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
'paper_ids': 'arXiv:2208.10442',
|
||||
'paper_name': 'Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks',
|
||||
'origin_url': 'https://github.com/microsoft/unilm/tree/master/beit3',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'beit3_base_patch16_224.in1k': _cfg(
|
||||
url='https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth',
|
||||
# hf_hub_id='timm/',
|
||||
),
|
||||
'beit3_base_patch16_224.indomain_in1k': _cfg(
|
||||
url='https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth',
|
||||
# hf_hub_id='timm/',
|
||||
),
|
||||
'beit3_large_patch16_224.in1k': _cfg(
|
||||
url='https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth',
|
||||
# hf_hub_id='timm/',
|
||||
),
|
||||
'beit3_large_patch16_224.indomain_in1k': _cfg(
|
||||
url='https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth',
|
||||
# hf_hub_id='timm/',
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
def checkpoint_filter_fn(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
model: BEiT3,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if 'model' in state_dict:
|
||||
state_dict = state_dict['model']
|
||||
|
||||
if 'patch_embed.proj.weight' in state_dict:
|
||||
return state_dict
|
||||
|
||||
state_dict.pop('beit3.text_embed.weight')
|
||||
state_dict.pop('beit3.vision_embed.mask_token')
|
||||
|
||||
out_dict = {}
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if '.B.' in k:
|
||||
continue
|
||||
elif 'vision_embed.cls_token' in k:
|
||||
k = 'cls_token'
|
||||
else:
|
||||
k = k.replace('beit3.', '')
|
||||
k = k.replace('embed_positions.', 'pos_embed.')
|
||||
k = k.replace('vision_embed.', 'patch_embed.')
|
||||
k = k.replace('encoder.', '')
|
||||
k = k.replace('layers.', 'blocks.')
|
||||
k = k.replace('ffn.', 'mlp.')
|
||||
k = k.replace('ffn_layernorm.', 'norm.')
|
||||
k = k.replace('self_attn.', 'attn.')
|
||||
k = k.replace('self_attn_layer_norm.', 'norm1.')
|
||||
k = k.replace('final_layer_norm.', 'norm2.')
|
||||
k = k.replace('A.', '')
|
||||
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
def _create_beit3(variant: str, pretrained: bool, **kwargs: Any) -> BEiT3:
|
||||
out_indices = kwargs.pop('out_indices', 3)
|
||||
model = build_model_with_cfg(
|
||||
BEiT3, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def beit3_base_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4)
|
||||
model = _create_beit3('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def beit3_large_patch16_224(pretrained: bool = False, **kwargs: Any) -> BEiT3:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4)
|
||||
model = _create_beit3('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
Loading…
x
Reference in New Issue
Block a user