Tweak DinoV2 add, add MAE ViT weights, add initial intermediate layer getter experiment

Ross Wightman 2023-05-09 17:59:22 -07:00
parent 59bea4c306
commit a01d8f86f4
6 changed files with 222 additions and 50 deletions

timm/layers/__init__.py

@@ -34,7 +34,7 @@ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct,
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
-from .patch_embed import PatchEmbed, resample_patch_embed
+from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords

timm/layers/patch_embed.py

@@ -9,7 +9,7 @@ Based on code in:
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
-from typing import List, Optional, Callable
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from torch import nn as nn
@@ -75,6 +75,49 @@ class PatchEmbed(nn.Module):
         return x


+class PatchEmbedWithSize(PatchEmbed):
+    """ 2D Image to Patch Embedding
+    """
+    output_fmt: Format
+
+    def __init__(
+            self,
+            img_size: Optional[int] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            norm_layer: Optional[Callable] = None,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            bias: bool = True,
+    ):
+        super().__init__(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer,
+            flatten=flatten,
+            output_fmt=output_fmt,
+            bias=bias,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        B, C, H, W = x.shape
+        if self.img_size is not None:
+            _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
+            _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
+
+        x = self.proj(x)
+        grid_size = x.shape[-2:]
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
+        x = self.norm(x)
+        return x, grid_size
+
+
 def resample_patch_embed(
         patch_embed,
         new_size: List[int],

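A rough usage sketch of the new PatchEmbedWithSize layer above, used standalone; the input shape and 16x16 patch config are just illustrative:

import torch
from timm.layers import PatchEmbedWithSize

# embed a batch of 224x224 RGB images with 16x16 patches
embed = PatchEmbedWithSize(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
tokens, grid_size = embed(x)  # unlike PatchEmbed, the feature grid size is returned alongside the tokens
print(tokens.shape)      # torch.Size([2, 196, 768]) -> 14 * 14 = 196 patch tokens
print(tuple(grid_size))  # (14, 14)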
timm/layers/pos_embed.py

@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
         verbose: bool = False,
 ):
     # sort out sizes, assume square if old size not provided
-    new_size = to_2tuple(new_size)
-    new_ntok = new_size[0] * new_size[1]
-    if not old_size:
-        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
-    old_size = to_2tuple(old_size)
-    if new_size == old_size:  # might not both be same container type
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb

+    if not old_size:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
     if num_prefix_tokens:
         posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
     else:
         posemb_prefix, posemb = None, posemb

     # do the interpolation
+    embed_dim = posemb.shape[-1]
     posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
     posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
-    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
-
-    if verbose:
-        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)

     # add back extra (class, etc) prefix tokens
     if posemb_prefix is not None:
+        print(posemb_prefix.shape, posemb.shape)
         posemb = torch.cat([posemb_prefix, posemb], dim=1)

+    if not torch.jit.is_scripting() and verbose:
+        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+
     return posemb

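A sketch of the reworked resample_abs_pos_embed being used to adapt a ViT-B/16 position embedding from a 14x14 to a 24x24 token grid (e.g. 224px to 384px input); tensor shapes are illustrative:

import torch
from timm.layers import resample_abs_pos_embed

# 14x14 grid + 1 class token, embed dim 768
posemb = torch.randn(1, 14 * 14 + 1, 768)

# interpolate to a 24x24 grid, leaving the class token untouched
posemb_384 = resample_abs_pos_embed(posemb, new_size=(24, 24), num_prefix_tokens=1)
print(posemb_384.shape)  # torch.Size([1, 577, 768]) -> 24 * 24 + 1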
timm/models/deit.py

@@ -11,11 +11,13 @@ Modifications copyright 2021, Ross Wightman
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
+from typing import Sequence, Union

 import torch
 from torch import nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import resample_abs_pos_embed
 from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -71,11 +73,37 @@ class VisionTransformerDistilled(VisionTransformer):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable

+    def _intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+    ):
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = torch.cat((
+            self.cls_token.expand(x.shape[0], -1, -1),
+            self.dist_token.expand(x.shape[0], -1, -1),
+            x),
+            dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
     def forward_features(self, x) -> torch.Tensor:
         x = self.patch_embed(x)
         x = torch.cat((
             self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1), x),
+            self.dist_token.expand(x.shape[0], -1, -1),
+            x),
             dim=1)
         x = self.pos_drop(x + self.pos_embed)
         if self.grad_checkpointing and not torch.jit.is_scripting():

timm/models/vision_transformer.py

@@ -27,7 +27,7 @@ import logging
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
@@ -125,7 +125,7 @@ class Block(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -142,7 +142,7 @@ class Block(nn.Module):
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

         self.norm2 = norm_layer(dim)
-        self.mlp = ffn_layer(
+        self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -172,7 +172,7 @@ class ResPostBlock(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.init_values = init_values
@@ -189,7 +189,7 @@ class ResPostBlock(nn.Module):
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self.mlp = ffn_layer(
+        self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -232,7 +232,7 @@ class ParallelScalingBlock(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=None,  # NOTE: not used
+            mlp_layer=None,  # NOTE: not used
     ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -326,7 +326,7 @@ class ParallelThingsBlock(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.num_parallel = num_parallel
@@ -349,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', ffn_layer(
+                ('mlp', mlp_layer(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
@@ -413,7 +413,7 @@ class VisionTransformer(nn.Module):
             norm_layer: Optional[Callable] = None,
             act_layer: Optional[Callable] = None,
             block_fn: Callable = Block,
-            ffn_layer: Callable = Mlp,
+            mlp_layer: Callable = Mlp,
     ):
         """
         Args:
@@ -435,7 +435,7 @@ class VisionTransformer(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             weight_init: Weight initialization scheme.
-            embed_layer: Patch embedding layey.
+            embed_layer: Patch embedding layer.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -490,7 +490,7 @@ class VisionTransformer(nn.Module):
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                ffn_layer=ffn_layer,
+                mlp_layer=mlp_layer,
             )
             for i in range(depth)])
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -560,6 +560,55 @@ class VisionTransformer(nn.Module):
         x = x + self.pos_embed
         return self.pos_drop(x)

+    def _intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+    ):
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
+    def get_intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+            reshape: bool = False,
+            return_class_token: bool = False,
+            norm: bool = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
+        Inspired by DINO / DINOv2 interface
+        """
+        # take last n blocks if n is an int, if n is a sequence, select by matching indices
+        outputs = self._intermediate_layers(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
+        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
+
+        if reshape:
+            grid_size = self.patch_embed.grid_size
+            outputs = [
+                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = self._pos_embed(x)
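A rough usage sketch for the WIP get_intermediate_layers accessor above, mirroring the DINO / DINOv2-style interface; the model name and input size are just examples (pass pretrained=True to pull actual weights):

import torch
import timm

model = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=False).eval()

x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    # features from the last 4 blocks, normalized, reshaped to NCHW maps, each paired with its class token
    outputs = model.get_intermediate_layers(x, n=4, reshape=True, return_class_token=True, norm=True)

for feat, cls_tok in outputs:
    print(feat.shape, cls_tok.shape)  # torch.Size([1, 768, 37, 37]) torch.Size([1, 1, 768])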
@@ -816,9 +865,7 @@ def _convert_openai_clip(state_dict, model):
 def _convert_dinov2(state_dict, model):
     import re
     out_dict = {}
     for k, v in state_dict.items():
         if k == "mask_token":
             continue
@@ -828,11 +875,10 @@ def _convert_dinov2(state_dict, model):
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
             out_dict[k.replace("w3", "fc2")] = v
             continue
         out_dict[k] = v
     return out_dict


 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -1072,19 +1118,27 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

-    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
-    'vit_small_patch14_dinov2': _cfg(
+    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_base_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_large_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_giant_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),

     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
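A sketch of how the retagged DINOv2 cfgs above would typically be consumed, assuming the fbaipublicfiles checkpoints are reachable:

import timm

# feature extraction: the cfg ships headless (num_classes=0), input_size (3, 518, 518), crop_pct 1.0
features = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True)

# fine-tuning: request a fresh classifier head at create time
classifier = timm.create_model('vit_large_patch14_dinov2.lvd142m', pretrained=True, num_classes=100)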
@@ -1359,6 +1413,22 @@ default_cfgs = generate_default_cfgs({
     'vit_base_patch16_xp_224.untrained': _cfg(url=''),
     'vit_large_patch14_xp_224.untrained': _cfg(url=''),
     'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
+
+    'vit_base_patch16_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
+        #hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_large_patch16_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch14_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 })
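Similarly, a sketch of pulling one of the newly added MAE encoders as a headless backbone (the hf_hub ids are still commented out above, so loading falls back to the fbaipublicfiles URL):

import torch
import timm

model = timm.create_model('vit_base_patch16_224.mae', pretrained=True).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    tokens = model.forward_features(x)   # (1, 197, 768) class + patch tokens
    pooled = model.forward_head(tokens)  # (1, 768) class-token features, no logits (num_classes=0)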
@@ -1904,10 +1974,8 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-S/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1.0, img_size=518,
     )
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1918,10 +1986,8 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-B/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1.0, img_size=518,
     )
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1932,14 +1998,13 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-L/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1.0, img_size=518,
     )
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-G/14 for DINOv2
@@ -1952,13 +2017,13 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
-        mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
+        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',

timm/models/vision_transformer_hybrid.py

@@ -14,6 +14,7 @@ They were moved here to keep file sizes sane.
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
+from typing import List, Tuple

 import torch
 import torch.nn as nn
@@ -74,10 +75,43 @@ class HybridEmbed(nn.Module):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
-        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        x = x.flatten(2).transpose(1, 2)
         return x


+class HybridEmbedWithSize(HybridEmbed):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+    def __init__(
+            self,
+            backbone,
+            img_size=224,
+            patch_size=1,
+            feature_size=None,
+            in_chans=3,
+            embed_dim=768,
+            bias=True,
+    ):
+        super().__init__(
+            backbone=backbone,
+            img_size=img_size,
+            patch_size=patch_size,
+            feature_size=feature_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=bias,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+        x = self.proj(x)
+        return x.flatten(2).transpose(1, 2), x.shape[-2:]
+
+
 def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
     embed_layer = partial(HybridEmbed, backbone=backbone)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
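A hypothetical standalone use of HybridEmbedWithSize (assuming it derives from HybridEmbed as above, so the parent constructor sets up the backbone and projection); the resnet26d backbone and feature index are only an example:

import torch
import timm
from timm.models.vision_transformer_hybrid import HybridEmbedWithSize

# wrap a CNN feature extractor and project its final feature map to ViT tokens
backbone = timm.create_model('resnet26d', pretrained=False, features_only=True, out_indices=(3,))
embed = HybridEmbedWithSize(backbone, img_size=224, patch_size=1, embed_dim=768)

x = torch.randn(2, 3, 224, 224)
tokens, feat_size = embed(x)
print(tokens.shape, tuple(feat_size))  # torch.Size([2, 196, 768]) (14, 14)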