Merge pull request #2213 from huggingface/florence2
Fix #2212 map florence2 image tower to davit with a few changespull/2217/head
commit
f8342a045a
|
@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
|
|||
def __init__(self, dim: int, k: int = 3, act: bool = False):
|
||||
super(ConvPosEnc, self).__init__()
|
||||
|
||||
self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
|
||||
self.proj = nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=k // 2,
|
||||
groups=dim,
|
||||
)
|
||||
self.act = nn.GELU() if act else nn.Identity()
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
|
@ -72,8 +79,9 @@ class Stem(nn.Module):
|
|||
|
||||
def forward(self, x: Tensor):
|
||||
B, C, H, W = x.shape
|
||||
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
|
||||
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
|
||||
pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
|
||||
pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
|
||||
x = F.pad(x, (0, pad_r, 0, pad_b))
|
||||
x = self.conv(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
@ -84,6 +92,7 @@ class Downsample(nn.Module):
|
|||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=3,
|
||||
norm_layer=LayerNorm2d,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -91,23 +100,58 @@ class Downsample(nn.Module):
|
|||
self.out_chs = out_chs
|
||||
|
||||
self.norm = norm_layer(in_chs)
|
||||
self.even_k = kernel_size % 2 == 0
|
||||
self.conv = nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=2,
|
||||
kernel_size=kernel_size,
|
||||
stride=2,
|
||||
padding=0,
|
||||
padding=0 if self.even_k else kernel_size // 2,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
B, C, H, W = x.shape
|
||||
x = self.norm(x)
|
||||
x = F.pad(x, (0, (2 - W % 2) % 2))
|
||||
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
|
||||
if self.even_k:
|
||||
k_h, k_w = self.conv.kernel_size
|
||||
pad_r = (k_w - W % k_w) % k_w
|
||||
pad_b = (k_h - H % k_h) % k_h
|
||||
x = F.pad(x, (0, pad_r , 0, pad_b))
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class ChannelAttentionV2(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
|
||||
super().__init__()
|
||||
self.groups = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.dynamic_scale = dynamic_scale
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
if self.dynamic_scale:
|
||||
q = q * N ** -0.5
|
||||
else:
|
||||
q = q * self.head_dim ** -0.5
|
||||
attn = q.transpose(-1, -2) @ k
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False):
|
||||
|
@ -147,13 +191,19 @@ class ChannelBlock(nn.Module):
|
|||
norm_layer=nn.LayerNorm,
|
||||
ffn=True,
|
||||
cpe_act=False,
|
||||
v2=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
||||
self.ffn = ffn
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
|
||||
self.attn = attn_layer(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
||||
|
||||
|
@ -372,13 +422,16 @@ class DaVitStage(nn.Module):
|
|||
attn_types=('spatial', 'channel'),
|
||||
num_heads=3,
|
||||
window_size=7,
|
||||
mlp_ratio=4,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path_rates=(0, 0),
|
||||
norm_layer=LayerNorm2d,
|
||||
norm_layer_cl=nn.LayerNorm,
|
||||
ffn=True,
|
||||
cpe_act=False
|
||||
cpe_act=False,
|
||||
down_kernel_size=2,
|
||||
named_blocks=False,
|
||||
channel_attn_v2=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -386,7 +439,7 @@ class DaVitStage(nn.Module):
|
|||
|
||||
# downsample embedding layer at the beginning of each stage
|
||||
if downsample:
|
||||
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
|
||||
self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
|
@ -399,10 +452,11 @@ class DaVitStage(nn.Module):
|
|||
'''
|
||||
stage_blocks = []
|
||||
for block_idx in range(depth):
|
||||
from collections import OrderedDict
|
||||
dual_attention_block = []
|
||||
for attn_idx, attn_type in enumerate(attn_types):
|
||||
if attn_type == 'spatial':
|
||||
dual_attention_block.append(SpatialBlock(
|
||||
dual_attention_block.append(('spatial_block', SpatialBlock(
|
||||
dim=out_chs,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
|
@ -412,9 +466,9 @@ class DaVitStage(nn.Module):
|
|||
ffn=ffn,
|
||||
cpe_act=cpe_act,
|
||||
window_size=window_size,
|
||||
))
|
||||
)))
|
||||
elif attn_type == 'channel':
|
||||
dual_attention_block.append(ChannelBlock(
|
||||
dual_attention_block.append(('channel_block', ChannelBlock(
|
||||
dim=out_chs,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
|
@ -422,9 +476,13 @@ class DaVitStage(nn.Module):
|
|||
drop_path=drop_path_rates[block_idx],
|
||||
norm_layer=norm_layer_cl,
|
||||
ffn=ffn,
|
||||
cpe_act=cpe_act
|
||||
))
|
||||
stage_blocks.append(nn.Sequential(*dual_attention_block))
|
||||
cpe_act=cpe_act,
|
||||
v2=channel_attn_v2,
|
||||
)))
|
||||
if named_blocks:
|
||||
stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
|
||||
else:
|
||||
stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
|
||||
self.blocks = nn.Sequential(*stage_blocks)
|
||||
|
||||
@torch.jit.ignore
|
||||
|
@ -473,6 +531,9 @@ class DaVit(nn.Module):
|
|||
attn_types=('spatial', 'channel'),
|
||||
ffn=True,
|
||||
cpe_act=False,
|
||||
down_kernel_size=2,
|
||||
channel_attn_v2=False,
|
||||
named_blocks=False,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_classes=1000,
|
||||
|
@ -512,6 +573,9 @@ class DaVit(nn.Module):
|
|||
norm_layer_cl=norm_layer_cl,
|
||||
ffn=ffn,
|
||||
cpe_act=cpe_act,
|
||||
down_kernel_size=down_kernel_size,
|
||||
channel_attn_v2=channel_attn_v2,
|
||||
named_blocks=named_blocks,
|
||||
)
|
||||
in_chs = out_chs
|
||||
stages.append(stage)
|
||||
|
@ -589,6 +653,34 @@ class DaVit(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def _convert_florence2(state_dict, model, prefix='vision_tower.'):
|
||||
import re
|
||||
out_dict = {}
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith(prefix):
|
||||
k = k.replace(prefix, '')
|
||||
else:
|
||||
continue
|
||||
k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
|
||||
k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
|
||||
k = k.replace('downsample.proj', 'downsample.conv')
|
||||
k = k.replace('stages.0.downsample', 'stem')
|
||||
#k = k.replace('head.', 'head.fc.')
|
||||
#k = k.replace('norms.', 'head.norm.')
|
||||
k = k.replace('window_attn.norm.', 'norm1.')
|
||||
k = k.replace('window_attn.fn.', 'attn.')
|
||||
k = k.replace('channel_attn.norm.', 'norm1.')
|
||||
k = k.replace('channel_attn.fn.', 'attn.')
|
||||
k = k.replace('ffn.norm.', 'norm2.')
|
||||
k = k.replace('ffn.fn.net.', 'mlp.')
|
||||
k = k.replace('conv1.fn.dw', 'cpe1.proj')
|
||||
k = k.replace('conv2.fn.dw', 'cpe2.proj')
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" Remap MSFT checkpoints -> timm """
|
||||
if 'head.fc.weight' in state_dict:
|
||||
|
@ -597,6 +689,9 @@ def checkpoint_filter_fn(state_dict, model):
|
|||
if 'state_dict' in state_dict:
|
||||
state_dict = state_dict['state_dict']
|
||||
|
||||
if 'vision_tower.convs.0.proj.weight' in state_dict:
|
||||
return _convert_florence2(state_dict, model)
|
||||
|
||||
import re
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
|
@ -615,13 +710,17 @@ def checkpoint_filter_fn(state_dict, model):
|
|||
def _create_davit(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
strict = True
|
||||
if variant.endswith('_fl'):
|
||||
# FIXME cleaner approach to missing head norm?
|
||||
strict = False
|
||||
model = build_model_with_cfg(
|
||||
DaVit,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
pretrained_strict=strict,
|
||||
**kwargs)
|
||||
|
||||
return model
|
||||
|
@ -650,6 +749,12 @@ default_cfgs = generate_default_cfgs({
|
|||
'davit_large': _cfg(),
|
||||
'davit_huge': _cfg(),
|
||||
'davit_giant': _cfg(),
|
||||
'davit_base_fl.msft_florence2': _cfg(
|
||||
hf_hub_id='microsoft/Florence-2-base',
|
||||
num_classes=0, input_size=(3, 768, 768)),
|
||||
'davit_huge_fl.msft_florence2': _cfg(
|
||||
hf_hub_id='microsoft/Florence-2-large',
|
||||
num_classes=0, input_size=(3, 768, 768)),
|
||||
})
|
||||
|
||||
|
||||
|
@ -687,3 +792,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
|
|||
def davit_giant(pretrained=False, **kwargs) -> DaVit:
|
||||
model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
|
||||
return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
|
||||
model_args = dict(
|
||||
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
|
||||
window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
|
||||
)
|
||||
return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
|
||||
# NOTE: huge image tower used in 'large' Florence2 model
|
||||
model_args = dict(
|
||||
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
|
||||
window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
|
||||
)
|
||||
return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
|
Loading…
Reference in New Issue