Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Tweak DinoV2 add, add MAE ViT weights, add initial intermediate layer getter experiment
parent 59bea4c306
commit a01d8f86f4
@@ -34,7 +34,7 @@ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct,
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
-from .patch_embed import PatchEmbed, resample_patch_embed
+from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
@@ -9,7 +9,7 @@ Based on code in:
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
-from typing import List, Optional, Callable
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn as nn
@@ -75,6 +75,49 @@ class PatchEmbed(nn.Module):
         return x
 
 
+class PatchEmbedWithSize(PatchEmbed):
+    """ 2D Image to Patch Embedding
+    """
+    output_fmt: Format
+
+    def __init__(
+            self,
+            img_size: Optional[int] = 224,
+            patch_size: int = 16,
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            norm_layer: Optional[Callable] = None,
+            flatten: bool = True,
+            output_fmt: Optional[str] = None,
+            bias: bool = True,
+    ):
+        super().__init__(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer,
+            flatten=flatten,
+            output_fmt=output_fmt,
+            bias=bias,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        B, C, H, W = x.shape
+        if self.img_size is not None:
+            _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
+            _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
+
+        x = self.proj(x)
+        grid_size = x.shape[-2:]
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
+        elif self.output_fmt != Format.NCHW:
+            x = nchw_to(x, self.output_fmt)
+        x = self.norm(x)
+        return x, grid_size
+
+
 def resample_patch_embed(
         patch_embed,
         new_size: List[int],
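For context, a minimal usage sketch of the new PatchEmbedWithSize (sizes are illustrative): unlike PatchEmbed, its forward returns the patch tokens together with the post-projection feature grid size.

```python
import torch
from timm.layers import PatchEmbedWithSize

# illustrative sizes; with img_size=None only divisibility by the patch size matters
embed = PatchEmbedWithSize(img_size=None, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 320)

tokens, grid_size = embed(x)
print(tokens.shape, tuple(grid_size))  # torch.Size([2, 280, 768]) (14, 20)
```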
@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
         verbose: bool = False,
 ):
     # sort out sizes, assume square if old size not provided
-    new_size = to_2tuple(new_size)
-    new_ntok = new_size[0] * new_size[1]
-    if not old_size:
-        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
-    old_size = to_2tuple(old_size)
-    if new_size == old_size:  # might not both be same container type
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
         return posemb
 
+    if not old_size:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
     if num_prefix_tokens:
         posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
     else:
         posemb_prefix, posemb = None, posemb
 
     # do the interpolation
+    embed_dim = posemb.shape[-1]
     posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
     posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
-    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
 
-    if verbose:
-        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
-
     # add back extra (class, etc) prefix tokens
     if posemb_prefix is not None:
-        print(posemb_prefix.shape, posemb.shape)
         posemb = torch.cat([posemb_prefix, posemb], dim=1)
+
+    if not torch.jit.is_scripting() and verbose:
+        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
+
     return posemb
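The reworked resample_abs_pos_embed infers the old grid from the token count, keeps the prefix (class) tokens aside during interpolation, and reshapes using the stored embed dim; a small sketch with illustrative shapes:

```python
import torch
from timm.layers import resample_abs_pos_embed

# illustrative ViT-B/16 pos_embed: 14x14 grid plus one class token
posemb = torch.randn(1, 1 + 14 * 14, 768)

# resample to a 24x24 grid (e.g. 384x384 input at patch 16); old size inferred as square
posemb_384 = resample_abs_pos_embed(posemb, new_size=(24, 24), num_prefix_tokens=1)
print(posemb_384.shape)  # torch.Size([1, 577, 768])
```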
@@ -11,11 +11,13 @@ Modifications copyright 2021, Ross Wightman
 # Copyright (c) 2015-present, Facebook, Inc.
 # All rights reserved.
 from functools import partial
+from typing import Sequence, Union
 
 import torch
 from torch import nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import resample_abs_pos_embed
 from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
@@ -71,11 +73,37 @@ class VisionTransformerDistilled(VisionTransformer):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
+    def _intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+    ):
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = torch.cat((
+            self.cls_token.expand(x.shape[0], -1, -1),
+            self.dist_token.expand(x.shape[0], -1, -1),
+            x),
+            dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
     def forward_features(self, x) -> torch.Tensor:
         x = self.patch_embed(x)
         x = torch.cat((
             self.cls_token.expand(x.shape[0], -1, -1),
-            self.dist_token.expand(x.shape[0], -1, -1), x),
+            self.dist_token.expand(x.shape[0], -1, -1),
+            x),
             dim=1)
         x = self.pos_drop(x + self.pos_embed)
         if self.grad_checkpointing and not torch.jit.is_scripting():
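Because VisionTransformerDistilled only overrides the internal `_intermediate_layers` (it has to prepend both the class and distillation tokens), the public `get_intermediate_layers` added to VisionTransformer later in this diff applies unchanged to distilled DeiT models. An illustrative sketch (model name and shapes are assumptions for the example):

```python
import torch
import timm

model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False)
feats = model.get_intermediate_layers(torch.randn(1, 3, 224, 224), n=2)
print([f.shape for f in feats])  # two tensors of shape [1, 196, 768]; prefix tokens stripped
```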
@@ -27,7 +27,7 @@ import logging
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -125,7 +125,7 @@ class Block(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -142,7 +142,7 @@ class Block(nn.Module):
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = ffn_layer(
+        self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -172,7 +172,7 @@ class ResPostBlock(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.init_values = init_values
@@ -189,7 +189,7 @@ class ResPostBlock(nn.Module):
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = ffn_layer(
+        self.mlp = mlp_layer(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
@@ -232,7 +232,7 @@ class ParallelScalingBlock(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=None,  # NOTE: not used
+            mlp_layer=None,  # NOTE: not used
     ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -326,7 +326,7 @@ class ParallelThingsBlock(nn.Module):
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
-            ffn_layer=Mlp,
+            mlp_layer=Mlp,
     ):
         super().__init__()
         self.num_parallel = num_parallel
@@ -349,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', ffn_layer(
+                ('mlp', mlp_layer(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
@@ -413,7 +413,7 @@ class VisionTransformer(nn.Module):
             norm_layer: Optional[Callable] = None,
             act_layer: Optional[Callable] = None,
             block_fn: Callable = Block,
-            ffn_layer: Callable = Mlp,
+            mlp_layer: Callable = Mlp,
     ):
         """
         Args:
@@ -435,7 +435,7 @@ class VisionTransformer(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             weight_init: Weight initialization scheme.
-            embed_layer: Patch embedding layey.
+            embed_layer: Patch embedding layer.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -490,7 +490,7 @@ class VisionTransformer(nn.Module):
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                ffn_layer=ffn_layer,
+                mlp_layer=mlp_layer,
             )
             for i in range(depth)])
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
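Note the keyword rename: custom MLP modules are now passed as `mlp_layer` instead of `ffn_layer`. A hedged sketch of constructing a small ViT with the packed SwiGLU MLP (arguments other than the renamed one are assumed from the existing VisionTransformer signature):

```python
from torch import nn
from timm.layers import SwiGLUPacked
from timm.models.vision_transformer import VisionTransformer

model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_layer=SwiGLUPacked,  # was ffn_layer= before this change
    act_layer=nn.SiLU,
)
```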
@@ -560,6 +560,55 @@ class VisionTransformer(nn.Module):
             x = x + self.pos_embed
         return self.pos_drop(x)
 
+    def _intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+    ):
+        outputs, num_blocks = [], len(self.blocks)
+        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+
+        # forward pass
+        x = self.patch_embed(x)
+        x = self._pos_embed(x)
+        x = self.patch_drop(x)
+        x = self.norm_pre(x)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in take_indices:
+                outputs.append(x)
+
+        return outputs
+
+    def get_intermediate_layers(
+            self,
+            x: torch.Tensor,
+            n: Union[int, Sequence] = 1,
+            reshape: bool = False,
+            return_class_token: bool = False,
+            norm: bool = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
+        Inspired by DINO / DINOv2 interface
+        """
+        # take last n blocks if n is an int, if n is a sequence, select by matching indices
+        outputs = self._intermediate_layers(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
+        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
+
+        if reshape:
+            grid_size = self.patch_embed.grid_size
+            outputs = [
+                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         x = self._pos_embed(x)
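A usage sketch of the experimental accessor (model name and shapes are illustrative):

```python
import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # last 4 block outputs, normed, prefix tokens stripped, reshaped to NCHW feature maps
    feats = model.get_intermediate_layers(x, n=4, reshape=True, norm=True)
print([f.shape for f in feats])  # 4 tensors of shape [1, 768, 14, 14]

# alternatively, select explicit block indices and keep the class token
outs = model.get_intermediate_layers(x, n=(2, 5, 8, 11), return_class_token=True)
```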
@@ -816,9 +865,7 @@ def _convert_openai_clip(state_dict, model):
 
 def _convert_dinov2(state_dict, model):
     import re
-
     out_dict = {}
-
     for k, v in state_dict.items():
         if k == "mask_token":
             continue
@@ -828,11 +875,10 @@ def _convert_dinov2(state_dict, model):
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
             out_dict[k.replace("w3", "fc2")] = v
             continue
-
         out_dict[k] = v
-
     return out_dict
 
+
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -1072,19 +1118,27 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
-    'vit_small_patch14_dinov2': _cfg(
+    # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_base_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_large_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
-    'vit_giant_patch14_dinov2': _cfg(
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
 
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
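With the pretrained-tag split, the DINOv2 weights are now addressed as `<model>.lvd142m` and configured as head-less feature extractors at 518x518 with crop_pct=1.0; a brief sketch:

```python
import timm

model = timm.create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True)
cfg = model.pretrained_cfg
print(cfg['num_classes'], cfg['input_size'], cfg['crop_pct'])  # 0 (3, 518, 518) 1.0
```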
@@ -1359,6 +1413,22 @@ default_cfgs = generate_default_cfgs({
     'vit_base_patch16_xp_224.untrained': _cfg(url=''),
     'vit_large_patch14_xp_224.untrained': _cfg(url=''),
     'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
+
+    'vit_base_patch16_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
+        #hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_large_patch16_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch14_224.mae': _cfg(
+        url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 })
 
 
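The MAE encoder weights are likewise head-less; a short sketch of using one for feature extraction:

```python
import torch
import timm

model = timm.create_model('vit_base_patch16_224.mae', pretrained=True).eval()

with torch.no_grad():
    feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 197, 768])
```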
@@ -1904,10 +1974,8 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-S/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1.0, img_size=518,
     )
-
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1918,10 +1986,8 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-B/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1.0, img_size=518,
     )
-
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1932,14 +1998,13 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-L/14 for DINOv2
     """
     model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
-        init_values=1.0, img_size=518,
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1.0, img_size=518,
     )
-
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 @register_model
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
     """ ViT-G/14 for DINOv2
@@ -1952,13 +2017,13 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
 
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
-        mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
+        mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
-
     model = _create_vision_transformer(
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
@@ -14,6 +14,7 @@ They were moved here to keep file sizes sane.
 Hacked together by / Copyright 2020, Ross Wightman
 """
 from functools import partial
+from typing import List, Tuple
 
 import torch
 import torch.nn as nn
@@ -74,10 +75,43 @@ class HybridEmbed(nn.Module):
         x = self.backbone(x)
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
-        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        x = x.flatten(2).transpose(1, 2)
         return x
 
 
+class HybridEmbedWithSize(HybridEmbed):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+    def __init__(
+            self,
+            backbone,
+            img_size=224,
+            patch_size=1,
+            feature_size=None,
+            in_chans=3,
+            embed_dim=768,
+            bias=True,
+    ):
+        super().__init__(
+            backbone=backbone,
+            img_size=img_size,
+            patch_size=patch_size,
+            feature_size=feature_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=bias,
+        )
+
+    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
+        x = self.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+        x = self.proj(x)
+        return x.flatten(2).transpose(1, 2), x.shape[-2:]
+
+
 def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
     embed_layer = partial(HybridEmbed, backbone=backbone)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set