Tweak DinoV2 add, add MAE ViT weights, add initial intermediate layer getter experiment

Ross Wightman 2023-05-09 17:59:22 -07:00
parent 59bea4c306
commit a01d8f86f4
6 changed files with 222 additions and 50 deletions
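
A quick, illustrative sketch of the DINOv2 change in this commit (the configs move to explicit .lvd142m pretrained tags, gaining a license field and crop_pct=1.0): the tag-qualified name can be passed straight to the model factory. The model name comes from the diff below; the rest is assumed usage, not part of the commit.

import timm
import torch

# headless feature extractor at the native 518x518 DINOv2 resolution
model = timm.create_model('vit_small_patch14_dinov2.lvd142m', pretrained=True, num_classes=0)
feats = model(torch.randn(1, 3, 518, 518))  # (1, 384) pooled features for ViT-S/14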

View File

@@ -34,7 +34,7 @@ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct,
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, resample_patch_embed
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords

View File

@@ -9,7 +9,7 @@ Based on code in:
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List, Optional, Callable
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn as nn
@@ -75,6 +75,49 @@ class PatchEmbed(nn.Module):
return x
class PatchEmbedWithSize(PatchEmbed):
""" 2D Image to Patch Embedding
"""
output_fmt: Format
def __init__(
self,
img_size: Optional[int] = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
):
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer,
flatten=flatten,
output_fmt=output_fmt,
bias=bias,
)
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
B, C, H, W = x.shape
if self.img_size is not None:
_assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
_assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
x = self.proj(x)
grid_size = x.shape[-2:]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x, grid_size
def resample_patch_embed(
patch_embed,
new_size: List[int],

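A rough usage sketch of the new PatchEmbedWithSize above (re-exported from timm.layers in this commit): with img_size=None the divisibility asserts are skipped, and forward returns both the flattened tokens and the projected grid size. Shapes are illustrative, not taken from the commit.

import torch
from timm.layers import PatchEmbedWithSize

embed = PatchEmbedWithSize(img_size=None, patch_size=16, in_chans=3, embed_dim=768)
tokens, grid_size = embed(torch.randn(2, 3, 224, 320))
print(tokens.shape)  # torch.Size([2, 280, 768]) -> 14 * 20 patches, flattened NLC
print(grid_size)     # torch.Size([14, 20])      -> H/16, W/16 of the projected feature map
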
View File

@@ -24,29 +24,31 @@ def resample_abs_pos_embed(
verbose: bool = False,
):
# sort out sizes, assume square if old size not provided
new_size = to_2tuple(new_size)
new_ntok = new_size[0] * new_size[1]
if not old_size:
old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
old_size = to_2tuple(old_size)
if new_size == old_size: # might not both be same container type
num_pos_tokens = posemb.shape[1]
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
return posemb
if not old_size:
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
old_size = hw, hw
if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb
# do the interpolation
embed_dim = posemb.shape[-1]
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
if verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
print(posemb_prefix.shape, posemb.shape)
posemb = torch.cat([posemb_prefix, posemb], dim=1)
if not torch.jit.is_scripting() and verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
return posemb

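A small sketch of the reworked resample_abs_pos_embed: a 14x14 ViT-B/16 position table with one class token resized to a 24x24 grid, plus the new short-circuit when the target is the same square grid. Values are illustrative.

import torch
from timm.layers import resample_abs_pos_embed

posemb = torch.randn(1, 1 + 14 * 14, 768)  # class token + 14x14 grid
resized = resample_abs_pos_embed(posemb, new_size=(24, 24), num_prefix_tokens=1)
print(resized.shape)  # torch.Size([1, 577, 768]) -> 1 prefix token + 24*24 grid
same = resample_abs_pos_embed(posemb, new_size=(14, 14), num_prefix_tokens=1)
print(same is posemb)  # True: an unchanged square grid returns the input as-is
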
View File

@@ -11,11 +11,13 @@ Modifications copyright 2021, Ross Wightman
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import Sequence, Union
import torch
from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
@@ -71,11 +73,37 @@ class VisionTransformerDistilled(VisionTransformer):
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
):
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
# forward pass
x = self.patch_embed(x)
x = torch.cat((
self.cls_token.expand(x.shape[0], -1, -1),
self.dist_token.expand(x.shape[0], -1, -1),
x),
dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
return outputs
def forward_features(self, x) -> torch.Tensor:
x = self.patch_embed(x)
x = torch.cat((
self.cls_token.expand(x.shape[0], -1, -1),
self.dist_token.expand(x.shape[0], -1, -1), x),
self.dist_token.expand(x.shape[0], -1, -1),
x),
dim=1)
x = self.pos_drop(x + self.pos_embed)
if self.grad_checkpointing and not torch.jit.is_scripting():

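The block selection in _intermediate_layers above (shared with the plain ViT version later in this commit) treats an int n as "the last n blocks" and a sequence as explicit block indices; a quick illustration for a 12-block model:

num_blocks = 12
n = 4
print(set(range(num_blocks - n, num_blocks)))  # {8, 9, 10, 11} -> outputs of the last 4 blocks
n = (2, 5, 8, 11)
print(set(n))                                  # {2, 5, 8, 11}  -> explicit block indices
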
View File

@@ -27,7 +27,7 @@ import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -125,7 +125,7 @@ class Block(nn.Module):
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
mlp_layer=Mlp,
):
super().__init__()
self.norm1 = norm_layer(dim)
@@ -142,7 +142,7 @@
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = ffn_layer(
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
@@ -172,7 +172,7 @@ class ResPostBlock(nn.Module):
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
mlp_layer=Mlp,
):
super().__init__()
self.init_values = init_values
@@ -189,7 +189,7 @@
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = ffn_layer(
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
@@ -232,7 +232,7 @@ class ParallelScalingBlock(nn.Module):
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn_layer=None, # NOTE: not used
mlp_layer=None, # NOTE: not used
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -326,7 +326,7 @@ class ParallelThingsBlock(nn.Module):
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
mlp_layer=Mlp,
):
super().__init__()
self.num_parallel = num_parallel
@@ -349,7 +349,7 @@
])))
self.ffns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('mlp', ffn_layer(
('mlp', mlp_layer(
dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
@@ -413,7 +413,7 @@ class VisionTransformer(nn.Module):
norm_layer: Optional[Callable] = None,
act_layer: Optional[Callable] = None,
block_fn: Callable = Block,
ffn_layer: Callable = Mlp,
mlp_layer: Callable = Mlp,
):
"""
Args:
@@ -435,7 +435,7 @@
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
embed_layer: Patch embedding layey.
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
@@ -490,7 +490,7 @@
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -560,6 +560,55 @@
x = x + self.pos_embed
return self.pos_drop(x)
def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
):
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
return outputs
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_class_token: bool = False,
norm: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
"""
# take last n blocks if n is an int; if n is a sequence, select blocks by matching indices
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
grid_size = self.patch_embed.grid_size
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward_features(self, x):
x = self.patch_embed(x)
x = self._pos_embed(x)
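
A usage sketch of the experimental accessor above (marked WIP, so the interface may still shift); the model name and shapes are illustrative:

import timm
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)
# last 4 blocks, final norm applied, patch tokens reshaped to NCHW maps, class token returned separately
outputs = model.get_intermediate_layers(x, n=4, reshape=True, return_class_token=True, norm=True)
for fmap, cls_tok in outputs:
    print(fmap.shape, cls_tok.shape)  # torch.Size([1, 768, 14, 14]) torch.Size([1, 1, 768])
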
@@ -816,9 +865,7 @@ def _convert_openai_clip(state_dict, model):
def _convert_dinov2(state_dict, model):
import re
out_dict = {}
for k, v in state_dict.items():
if k == "mask_token":
continue
@@ -828,11 +875,10 @@ def _convert_dinov2(state_dict, model):
elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
out_dict[k.replace("w3", "fc2")] = v
continue
out_dict[k] = v
return out_dict
def checkpoint_filter_fn(
state_dict,
model,
@@ -1072,19 +1118,27 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
# DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune only)
'vit_small_patch14_dinov2': _cfg(
# DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
'vit_small_patch14_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
'vit_base_patch14_dinov2': _cfg(
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
'vit_base_patch14_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
'vit_large_patch14_dinov2': _cfg(
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
'vit_large_patch14_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
'vit_giant_patch14_dinov2': _cfg(
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
'vit_giant_patch14_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518)),
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil.in21k': _cfg(
@@ -1359,6 +1413,22 @@ default_cfgs = generate_default_cfgs({
'vit_base_patch16_xp_224.untrained': _cfg(url=''),
'vit_large_patch14_xp_224.untrained': _cfg(url=''),
'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
'vit_base_patch16_224.mae': _cfg(
url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
#hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_large_patch16_224.mae': _cfg(
url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_huge_patch14_224.mae': _cfg(
url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
})
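
The MAE entries above are encoder-only pretraining weights (no classifier head, hence num_classes=0 in the cfgs); a hedged sketch of pulling them in, assuming the checkpoints resolve through the standard ViT checkpoint filter:

import timm
import torch

# headless feature extraction with the MAE-pretrained encoder
backbone = timm.create_model('vit_base_patch16_224.mae', pretrained=True, num_classes=0)
feats = backbone(torch.randn(1, 3, 224, 224))  # (1, 768) pooled features

# for fine-tuning, request a head; it is freshly initialized since the checkpoint carries none
clf = timm.create_model('vit_base_patch16_224.mae', pretrained=True, num_classes=10)
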
@@ -1904,10 +1974,8 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-S/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=384, depth=12, num_heads=6,
init_values=1.0, img_size=518,
patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1.0, img_size=518,
)
model = _create_vision_transformer(
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@@ -1918,10 +1986,8 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-B/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=768, depth=12, num_heads=12,
init_values=1.0, img_size=518,
patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1.0, img_size=518,
)
model = _create_vision_transformer(
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@@ -1932,14 +1998,13 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-L/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16,
init_values=1.0, img_size=518,
patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1.0, img_size=518,
)
model = _create_vision_transformer(
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
""" ViT-G/14 for DINOv2
@@ -1952,13 +2017,13 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs):
model_args = dict(
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1.0,
mlp_ratio=2.66667 * 2, ffn_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
)
model = _create_vision_transformer(
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
register_model_deprecations(__name__, {
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',

View File

@@ -14,6 +14,7 @@ They were moved here to keep file sizes sane.
Hacked together by / Copyright 2020, Ross Wightman
"""
from functools import partial
from typing import List, Tuple
import torch
import torch.nn as nn
@@ -74,10 +75,43 @@ class HybridEmbed(nn.Module):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class HybridEmbedWithSize(HybridEmbed):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(
self,
backbone,
img_size=224,
patch_size=1,
feature_size=None,
in_chans=3,
embed_dim=768,
bias=True,
):
super().__init__(
backbone=backbone,
img_size=img_size,
patch_size=patch_size,
feature_size=feature_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=bias,
)
def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x)
return x.flatten(2).transpose(1, 2), x.shape[-2:]
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
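
A rough sketch of HybridEmbedWithSize in isolation; it is not re-exported from timm.layers in this commit, so the import path and the features_only backbone below are assumptions:

import timm
import torch
from timm.models.vision_transformer_hybrid import HybridEmbedWithSize  # path assumed, not re-exported

backbone = timm.create_model('resnet26', pretrained=False, features_only=True, out_indices=(4,))
embed = HybridEmbedWithSize(backbone, img_size=224, patch_size=1, embed_dim=768)
tokens, feat_size = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 49, 768]) -> 7x7 stride-32 feature map, projected and flattened
print(feat_size)     # torch.Size([7, 7])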