Fix loading of the larger ResNet CLIP models. Also experimenting with making AttentionPool *the* head; it seems to fine-tune better, and it's one less layer.

Ross Wightman 2024-06-10 12:07:14 -07:00
parent 5e9ff5798f
commit 30ffa152de
2 changed files with 117 additions and 65 deletions
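A minimal standalone sketch of the idea (module and arguments as they appear in this diff; shapes assume a ResNet-50 trunk at 224x224 and are for illustration only):

import torch
from timm.layers import AttentionPool2d

# CLIP-style attention pool acting as the whole head: pooling and projection in one module
feats = torch.randn(2, 2048, 7, 7)    # NCHW feature map from a ResNet-50 trunk at 224x224
head = AttentionPool2d(2048, feat_size=7, out_features=1000, qkv_separate=True)

print(head(feats).shape)                    # torch.Size([2, 1000])  pooled + projected (logits)
print(head(feats, pre_logits=True).shape)   # torch.Size([2, 2048])  pooled only, before head.proj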

timm/layers/attention_pool2d.py

@@ -41,9 +41,10 @@ class RotAttentionPool2d(nn.Module):
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
drop: float = 0.,
):
super().__init__()
embed_dim = embed_dim or in_features
self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features
self.out_features = out_features or in_features
ref_feat_size = to_2tuple(ref_feat_size)
@@ -82,7 +83,7 @@ class RotAttentionPool2d(nn.Module):
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
def forward(self, x):
def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape
N = H * W
x = x.flatten(2).transpose(1, 2)
@@ -107,8 +108,12 @@ class RotAttentionPool2d(nn.Module):
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = x[:, 0]
x = self.drop(x)
if pre_logits:
return x
x = self.proj(x)
return x[:, 0]
return x
class AttentionPool2d(nn.Module):
@@ -132,9 +137,10 @@ class AttentionPool2d(nn.Module):
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
drop: float = 0.,
):
super().__init__()
embed_dim = embed_dim or in_features
self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features
self.out_features = out_features or in_features
if num_heads is not None:
@@ -158,6 +164,7 @@ class AttentionPool2d(nn.Module):
else:
self.q = self.k = self.v = None
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
@@ -178,15 +185,12 @@ class AttentionPool2d(nn.Module):
nn.init.zeros_(self.qkv.bias)
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
def forward(self, x):
def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape
N = H * W
x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
if self.seq_len != N:
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
else:
pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
x = x + pos_embed
if self.qkv is None:
@@ -205,5 +209,9 @@ class AttentionPool2d(nn.Module):
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = x[:, 0]
x = self.drop(x)
if pre_logits:
return x
x = self.proj(x)
return x[:, 0]
return x
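Both pool classes gain the same drop/pre_logits handling; a small sketch with the rotary variant (the out_features value is illustrative, not from this diff):

import torch
from timm.layers import RotAttentionPool2d

pool = RotAttentionPool2d(2048, out_features=512, ref_feat_size=7)
x = torch.randn(2, 2048, 7, 7)
print(pool(x).shape)                   # torch.Size([2, 512])   pooled + projected
print(pool(x, pre_logits=True).shape)  # torch.Size([2, 2048])  pooled, after drop, before proj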

timm/models/byobnet.py

@@ -37,7 +37,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
@@ -82,7 +82,6 @@ class ByoModelCfg:
aa_layer: str = ''
# Head config
attn_pool: str = ''
head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output
head_type: str = ''
@@ -304,7 +303,10 @@ class BottleneckBlock(nn.Module):
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
groups = num_groups(group_size, mid_chs)
self.shortcut = create_shortcut(
downsample, in_chs, out_chs,
stride=stride, dilation=dilation, apply_act=False, layers=layers,
)
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.conv2_kxk = layers.conv_norm_act(
@@ -321,10 +323,7 @@ class BottleneckBlock(nn.Module):
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
self.shortcut = create_shortcut(
downsample, in_chs, out_chs,
stride=stride, dilation=dilation, apply_act=False, layers=layers,
)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
@@ -1165,7 +1164,7 @@ def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True):
act = get_act_layer(cfg.act_layer)
norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
if cfg.aa_layer and allow_aa:
conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
else:
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
@@ -1258,6 +1257,7 @@ class ByobNet(nn.Module):
self.stage_ends = [f['stage'] for f in self.feature_info]
self.head_hidden_size = self.num_features
self.global_pool = global_pool
assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
if cfg.head_type == 'norm_mlp_classifier':
from timm.layers import NormMlpClassifierHead
@@ -1272,33 +1272,61 @@ class ByobNet(nn.Module):
)
self.head_hidden_size = self.head.hidden_size
else:
if cfg.attn_pool == 'abs':
from timm.layers import AttentionPool2d
self.attn_pool = AttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
feat_size=feat_size,
qkv_separate=True,
)
self.head_hidden_size = self.attn_pool.out_features
elif cfg.attn_pool == 'rot':
from timm.layers import RotAttentionPool2d
self.attn_pool = RotAttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
ref_feat_size=feat_size,
)
self.head_hidden_size = self.attn_pool.out_features
else:
assert not cfg.attn_pool
self.attn_pool = nn.Identity()
# FIXME evaluating different head vs pool configurations
if False:
if global_pool == 'attn_abs':
from timm.layers import AttentionPool2d
self.attn_pool = AttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
feat_size=feat_size,
qkv_separate=True,
)
global_pool = '' # clear for ClassifierHead
self.head_hidden_size = self.attn_pool.out_features
elif global_pool == 'attn_rot':
from timm.layers import RotAttentionPool2d
self.attn_pool = RotAttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
ref_feat_size=feat_size,
)
global_pool = '' # clear for ClassifierHead
self.head_hidden_size = self.attn_pool.out_features
else:
self.attn_pool = nn.Identity()
self.head = ClassifierHead(
self.head_hidden_size,
num_classes,
pool_type='' if cfg.attn_pool else global_pool,
drop_rate=self.drop_rate,
)
self.head = ClassifierHead(
self.head_hidden_size,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
else:
if global_pool == 'attn_abs':
from timm.layers import AttentionPool2d
self.head = AttentionPool2d(
self.num_features,
out_features=num_classes,
feat_size=feat_size,
qkv_separate=True,
)
self.head_hidden_size = self.head.embed_dim
elif global_pool == 'attn_rot':
from timm.layers import RotAttentionPool2d
self.head = RotAttentionPool2d(
self.num_features,
out_features=num_classes,
ref_feat_size=feat_size,
)
self.head_hidden_size = self.head.embed_dim
else:
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
# init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
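A hedged usage sketch of the new head selection (model name from the registrations further down; assumes create_model fills the fixed 224x224 input size from the default cfg so feat_size resolves to 7x7):

import timm

m = timm.create_model('resnet50_clip', global_pool='attn_abs', num_classes=10)
print(type(m.head).__name__)    # AttentionPool2d  -- the pool is the classifier head

m = timm.create_model('resnet50_clip', global_pool='attn_rot', num_classes=10)
print(type(m.head).__name__)    # RotAttentionPool2d

m = timm.create_model('resnet50_clip', global_pool='avg', num_classes=10)
print(type(m.head).__name__)    # ClassifierHead   -- plain pooled head, no attention pool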
@@ -1324,6 +1352,9 @@ class ByobNet(nn.Module):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
if self.global_pool in ('attn_abs', 'attn_rot'):
raise RuntimeError('Cannot change attention pool on head reset.')
self.head.reset(num_classes, global_pool)
def forward_intermediates(
@@ -1413,7 +1444,7 @@ class ByobNet(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.attn_pool(x)
#x = self.attn_pool(x)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
@@ -1916,7 +1947,6 @@ model_cfgs = dict(
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=1024,
),
@@ -1932,7 +1962,6 @@ model_cfgs = dict(
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=512,
),
@@ -1949,7 +1978,6 @@ model_cfgs = dict(
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=640,
),
@@ -1960,12 +1988,12 @@ model_cfgs = dict(
ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
),
width_factor=1.5,
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=768,
),
@@ -1976,12 +2004,12 @@ model_cfgs = dict(
ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
),
width_factor=2.0,
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=1024,
),
@@ -2029,10 +2057,10 @@ def _convert_openai_clip(
continue
k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.([a-z]+)([0-9])', _stage_sub, k)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.downsample\.([0-9])', _down_sub, k)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k)
if k.startswith(f'{prefix}attnpool'):
k = k.replace(prefix + 'attnpool', 'attn_pool')
k = k.replace(prefix + 'attnpool', 'head') #'attn_pool')
k = k.replace('positional_embedding', 'pos_embed')
k = k.replace('q_proj', 'q')
k = k.replace('k_proj', 'k')
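The regex change above is the actual load fix: the deeper CLIP towers have block indices of 10 or more, which the old single-digit group never matched. A quick sketch (key name illustrative, patterns taken from the diff):

import re

k = 'visual.layer3.12.conv1.weight'   # e.g. ResNet-101 has 23 blocks in layer3
print(re.match(r'visual\.layer([0-9])\.([0-9])\.([a-z]+)([0-9])', k))    # None: old pattern misses two-digit block indices
print(re.match(r'visual\.layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', k))   # <re.Match ...>: new pattern converts them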
@@ -2053,13 +2081,19 @@
def _create_byobnet(variant, pretrained=False, **kwargs):
strict = True
if 'clip' in variant and kwargs.get('global_pool', None) != 'attn_abs':
# NOTE: a hack to allow removing attention pool from CLIP ResNet variants
strict = False
return build_model_with_cfg(
ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
#pretrained_strict=False,
**kwargs)
pretrained_strict=strict,
**kwargs,
)
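A hedged example of the non-strict path this enables (downloads the OpenAI weights; the checkpoint's attention-pool tensors should simply be skipped since the model no longer has a matching head):

import timm

# CLIP ResNet without its attention pool: the head becomes a plain ClassifierHead and the
# pretrained attn-pool weights are dropped because loading is no longer strict
m = timm.create_model('resnet50_clip.openai', pretrained=True, global_pool='avg', num_classes=10)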
def _cfg(url='', **kwargs):
@@ -2257,31 +2291,36 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet101_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj',
),
'resnet50x4_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9)
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
classifier='head.proj',
),
'resnet50x16_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12)
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
classifier='head.proj',
),
'resnet50x64_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14)
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
classifier='head.proj',
),
})
@@ -2592,35 +2631,40 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50 CLIP image tower
"""
return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
model_args = dict(global_pool='attn_abs')
return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-101 CLIP image tower
"""
return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
model_args = dict(global_pool='attn_abs')
return _create_byobnet('resnet101_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x4 CLIP image tower
"""
return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
model_args = dict(global_pool='attn_abs')
return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x16 CLIP image tower
"""
return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
model_args = dict(global_pool='attn_abs')
return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x64 CLIP image tower
"""
return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
model_args = dict(global_pool='attn_abs')
return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model