Merge pull request #2213 from huggingface/florence2

Fix #2212: map the Florence-2 image tower to DaViT, with a few changes
Commit f8342a045a by Ross Wightman, 2024-06-24 11:01:08 -07:00 (committed by GitHub)


@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
     def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()
-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
         self.act = nn.GELU() if act else nn.Identity()
 
     def forward(self, x: Tensor):
@@ -72,8 +79,9 @@ class Stem(nn.Module):
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
-        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
-        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
+        pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
+        pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
+        x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         x = self.norm(x)
         return x
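
The consolidated padding above relies on F.pad's last-dim-first ordering for NCHW tensors, i.e. the 4-tuple is (left, right, top, bottom). A minimal sketch of the arithmetic with illustrative sizes (H=10, W=11), assuming the Stem's default stride of 4:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 11)                       # H=10, W=11 (illustrative)
stride = (4, 4)                                     # assumed Stem default
pad_r = (stride[1] - 11 % stride[1]) % stride[1]    # 1 -> W padded to 12
pad_b = (stride[0] - 10 % stride[0]) % stride[0]    # 2 -> H padded to 12
x = F.pad(x, (0, pad_r, 0, pad_b))                  # (left, right, top, bottom) for NCHW
assert x.shape[-2:] == (12, 12)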
@@ -84,6 +92,7 @@ class Downsample(nn.Module):
             self,
             in_chs,
             out_chs,
+            kernel_size=3,
             norm_layer=LayerNorm2d,
     ):
         super().__init__()
@@ -91,23 +100,58 @@ class Downsample(nn.Module):
         self.out_chs = out_chs
 
         self.norm = norm_layer(in_chs)
+        self.even_k = kernel_size % 2 == 0
         self.conv = nn.Conv2d(
             in_chs,
             out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
             stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
         )
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
         x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            pad_r = (k_w - W % k_w) % k_w
+            pad_b = (k_h - H % k_h) % k_h
+            x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         return x
 
 
+class ChannelAttentionV2(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
+        super().__init__()
+        self.groups = num_heads
+        self.head_dim = dim // num_heads
+        self.dynamic_scale = dynamic_scale
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        if self.dynamic_scale:
+            q = q * N ** -0.5
+        else:
+            q = q * self.head_dim ** -0.5
+
+        attn = q.transpose(-1, -2) @ k
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
 class ChannelAttention(nn.Module):
 
     def __init__(self, dim, num_heads=8, qkv_bias=False):
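
The even_k branch above keeps the original Downsample behaviour for the 2x2/stride-2 kernel (pad right/bottom to a multiple of the kernel, conv padding 0), while the Florence-2 towers use kernel_size=3 with padding=k//2 and no explicit padding. A quick sketch with illustrative shapes showing that both paths reduce an odd spatial size to ceil(H/2):

import math
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 8, 7, 7)   # odd H/W (illustrative)

# even kernel (kernel_size=2): pad to a multiple of 2, conv has padding=0
pad = (2 - 7 % 2) % 2
even = nn.Conv2d(8, 16, kernel_size=2, stride=2, padding=0)(F.pad(x, (0, pad, 0, pad)))

# odd kernel (kernel_size=3, the Florence-2 setting): padding=k//2, no explicit F.pad
odd = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)(x)

assert even.shape[-2:] == odd.shape[-2:] == (math.ceil(7 / 2),) * 2   # both 4x4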
@@ -147,13 +191,19 @@ class ChannelBlock(nn.Module):
             norm_layer=nn.LayerNorm,
             ffn=True,
             cpe_act=False,
+            v2=False,
     ):
         super().__init__()
 
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
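
ChannelBlock can now swap in the ChannelAttentionV2 defined above (v2=True, exposed as channel_attn_v2 further down). A functional sketch of that attention with illustrative shapes: the attention matrix is formed over the per-head channel dimension rather than over tokens, and dynamic_scale scales queries by N**-0.5 instead of head_dim**-0.5.

import torch

def channel_attn_v2(q, k, v, dynamic_scale=True):
    # q, k, v: (B, heads, N, head_dim), as produced by the qkv projection + reshape above
    B, H, N, D = q.shape
    q = q * (N ** -0.5 if dynamic_scale else D ** -0.5)
    attn = (q.transpose(-1, -2) @ k).softmax(dim=-1)      # (B, heads, D, D): channel x channel
    out = (attn @ v.transpose(-1, -2)).transpose(-1, -2)  # back to (B, heads, N, D)
    return out

q = k = v = torch.randn(2, 8, 196, 32)   # illustrative shapes
assert channel_attn_v2(q, k, v).shape == (2, 8, 196, 32)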
@@ -372,13 +422,16 @@ class DaVitStage(nn.Module):
             attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
             qkv_bias=True,
             drop_path_rates=(0, 0),
             norm_layer=LayerNorm2d,
             norm_layer_cl=nn.LayerNorm,
             ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
     ):
         super().__init__()
@@ -386,7 +439,7 @@ class DaVitStage(nn.Module):
 
         # downsample embedding layer at the beginning of each stage
         if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
         else:
             self.downsample = nn.Identity()
@@ -399,10 +452,11 @@ class DaVitStage(nn.Module):
         '''
        stage_blocks = []
         for block_idx in range(depth):
+            from collections import OrderedDict
             dual_attention_block = []
             for attn_idx, attn_type in enumerate(attn_types):
                 if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -412,9 +466,9 @@ class DaVitStage(nn.Module):
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
-                    ))
+                    )))
                 elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -422,9 +476,13 @@ class DaVitStage(nn.Module):
                         drop_path=drop_path_rates[block_idx],
                         norm_layer=norm_layer_cl,
                         ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
         self.blocks = nn.Sequential(*stage_blocks)
 
     @torch.jit.ignore
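
The named_blocks path exists so that sub-block parameter names line up with the Florence-2 checkpoint layout. A minimal sketch of the difference, with nn.Identity standing in for the real spatial/channel blocks:

from collections import OrderedDict
from torch import nn

pairs = [('spatial_block', nn.Identity()), ('channel_block', nn.Identity())]

named = nn.Sequential(OrderedDict(pairs))          # named_blocks=True
indexed = nn.Sequential(*[b[1] for b in pairs])    # named_blocks=False

assert [n for n, _ in named.named_children()] == ['spatial_block', 'channel_block']
assert [n for n, _ in indexed.named_children()] == ['0', '1']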
@@ -473,6 +531,9 @@ class DaVit(nn.Module):
             attn_types=('spatial', 'channel'),
             ffn=True,
             cpe_act=False,
+            down_kernel_size=2,
+            channel_attn_v2=False,
+            named_blocks=False,
             drop_rate=0.,
             drop_path_rate=0.,
             num_classes=1000,
@@ -512,6 +573,9 @@ class DaVit(nn.Module):
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
                 cpe_act=cpe_act,
+                down_kernel_size=down_kernel_size,
+                channel_attn_v2=channel_attn_v2,
+                named_blocks=named_blocks,
             )
             in_chs = out_chs
             stages.append(stage)
@@ -589,6 +653,34 @@ class DaVit(nn.Module):
         return x
 
 
+def _convert_florence2(state_dict, model, prefix='vision_tower.'):
+    import re
+    out_dict = {}
+    for k, v in state_dict.items():
+        if k.startswith(prefix):
+            k = k.replace(prefix, '')
+        else:
+            continue
+        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
+        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
+        #k = k.replace('head.', 'head.fc.')
+        #k = k.replace('norms.', 'head.norm.')
+        k = k.replace('window_attn.norm.', 'norm1.')
+        k = k.replace('window_attn.fn.', 'attn.')
+        k = k.replace('channel_attn.norm.', 'norm1.')
+        k = k.replace('channel_attn.fn.', 'attn.')
+        k = k.replace('ffn.norm.', 'norm2.')
+        k = k.replace('ffn.fn.net.', 'mlp.')
+        k = k.replace('conv1.fn.dw', 'cpe1.proj')
+        k = k.replace('conv2.fn.dw', 'cpe2.proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.fc.weight' in state_dict:
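
As a quick check of the remapping, tracing the key that checkpoint_filter_fn sniffs for ('vision_tower.convs.0.proj.weight') through the substitutions above yields the timm stem weight name; the second key here is illustrative of the per-stage downsample pattern and may not match the real Florence-2 naming exactly.

import re

def remap(k, prefix='vision_tower.'):
    # same substitution chain as _convert_florence2, minus the block-internal renames
    k = k.replace(prefix, '')
    k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
    k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('downsample.proj', 'downsample.conv')
    k = k.replace('stages.0.downsample', 'stem')
    return k

assert remap('vision_tower.convs.0.proj.weight') == 'stem.conv.weight'
assert remap('vision_tower.convs.1.proj.weight') == 'stages.1.downsample.conv.weight'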
@@ -597,6 +689,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
 
+    if 'vision_tower.convs.0.proj.weight' in state_dict:
+        return _convert_florence2(state_dict, model)
+
     import re
     out_dict = {}
     for k, v in state_dict.items():
@@ -615,13 +710,17 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
+    strict = True
+    if variant.endswith('_fl'):
+        # FIXME cleaner approach to missing head norm?
+        strict = False
     model = build_model_with_cfg(
         DaVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=strict,
         **kwargs)
 
     return model
@@ -650,6 +749,12 @@ default_cfgs = generate_default_cfgs({
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
+    'davit_base_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-base',
+        num_classes=0, input_size=(3, 768, 768)),
+    'davit_huge_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-large',
+        num_classes=0, input_size=(3, 768, 768)),
 })
@@ -687,3 +792,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
 def davit_giant(pretrained=False, **kwargs) -> DaVit:
     model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
     return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
+    # NOTE: huge image tower used in 'large' Florence2 model
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
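
A minimal usage sketch for the new variants, assuming the pretrained weights resolve from the configured hf_hub_id and pass through checkpoint_filter_fn; the *_fl default_cfgs set num_classes=0, so the models act as headless image towers.

import timm
import torch

model = timm.create_model('davit_base_fl', pretrained=True)
model.eval()

x = torch.randn(1, 3, 768, 768)       # default_cfg input_size for the *_fl variants
with torch.no_grad():
    feats = model.forward_features(x)
print(feats.shape)                    # expected (1, 1024, 24, 24) for davit_base_fl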