diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py
index dc594b70..443a384c 100644
--- a/timm/layers/attention_pool2d.py
+++ b/timm/layers/attention_pool2d.py
@@ -41,9 +41,10 @@ class RotAttentionPool2d(nn.Module):
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         ref_feat_size = to_2tuple(ref_feat_size)
@@ -82,7 +83,7 @@ class RotAttentionPool2d(nn.Module):
             trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
             nn.init.zeros_(self.qkv.bias)
 
-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
@@ -107,8 +108,12 @@ class RotAttentionPool2d(nn.Module):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x
 
 
 class AttentionPool2d(nn.Module):
@@ -132,9 +137,10 @@ class AttentionPool2d(nn.Module):
             num_heads: Optional[int] = None,
             qkv_bias: bool = True,
             qkv_separate: bool = False,
+            drop: float = 0.,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
+        self.embed_dim = embed_dim = embed_dim or in_features
         self.in_features = in_features
         self.out_features = out_features or in_features
         if num_heads is not None:
@@ -158,6 +164,7 @@ class AttentionPool2d(nn.Module):
         else:
             self.q = self.k = self.v = None
             self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.drop = nn.Dropout(drop)
         self.proj = nn.Linear(embed_dim, self.out_features)
         self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
 
@@ -178,15 +185,12 @@ class AttentionPool2d(nn.Module):
             nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
 
-    def forward(self, x):
+    def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        if self.seq_len != N:
-            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
-        else:
-            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
 
         if self.qkv is None:
@@ -205,5 +209,9 @@ class AttentionPool2d(nn.Module):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
+        x = x[:, 0]
+        x = self.drop(x)
+        if pre_logits:
+            return x
         x = self.proj(x)
-        return x[:, 0]
+        return x
diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index b9417dfe..f84f2455 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -37,7 +37,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \
+from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -82,7 +82,6 @@ class ByoModelCfg:
     aa_layer: str = ''
 
     # Head config
-    attn_pool: str = ''
     head_hidden_size: Optional[int] = None  # feat dim of MLP head or AttentionPool output
     head_type: str = ''
 
@@ -304,7 +303,10 @@ class BottleneckBlock(nn.Module):
 
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
         groups = num_groups(group_size, mid_chs)
-
+        self.shortcut = create_shortcut(
+            downsample, in_chs, out_chs,
+            stride=stride, dilation=dilation, apply_act=False, layers=layers,
+        )
         self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
         self.conv2_kxk = layers.conv_norm_act(
@@ -321,10 +323,7 @@ class BottleneckBlock(nn.Module):
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
-        self.shortcut = create_shortcut(
-            downsample, in_chs, out_chs,
-            stride=stride, dilation=dilation, apply_act=False, layers=layers,
-        )
+
     def init_weights(self, zero_init_last: bool = False):
         if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
@@ -1165,7 +1164,7 @@ def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True):
     act = get_act_layer(cfg.act_layer)
     norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
     if cfg.aa_layer and allow_aa:
-        conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
+        conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
     else:
         conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
     attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
@@ -1258,6 +1257,7 @@ class ByobNet(nn.Module):
         self.stage_ends = [f['stage'] for f in self.feature_info]
 
         self.head_hidden_size = self.num_features
+        self.global_pool = global_pool
         assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
         if cfg.head_type == 'norm_mlp_classifier':
             from timm.layers import NormMlpClassifierHead
@@ -1272,33 +1272,61 @@ class ByobNet(nn.Module):
             )
             self.head_hidden_size = self.head.hidden_size
         else:
-            if cfg.attn_pool == 'abs':
-                from timm.layers import AttentionPool2d
-                self.attn_pool = AttentionPool2d(
-                    self.num_features,
-                    out_features=cfg.head_hidden_size,
-                    feat_size=feat_size,
-                    qkv_separate=True,
-                )
-                self.head_hidden_size = self.attn_pool.out_features
-            elif cfg.attn_pool == 'rot':
-                from timm.layers import RotAttentionPool2d
-                self.attn_pool = RotAttentionPool2d(
-                    self.num_features,
-                    out_features=cfg.head_hidden_size,
-                    ref_feat_size=feat_size,
-                )
-                self.head_hidden_size = self.attn_pool.out_features
-            else:
-                assert not cfg.attn_pool
-                self.attn_pool = nn.Identity()
+            # FIXME evaluating different head vs pool configurations
+            if False:
+                if global_pool == 'attn_abs':
+                    from timm.layers import AttentionPool2d
+                    self.attn_pool = AttentionPool2d(
+                        self.num_features,
+                        out_features=cfg.head_hidden_size,
+                        feat_size=feat_size,
+                        qkv_separate=True,
+                    )
+                    global_pool = ''  # clear for ClassifierHead
+                    self.head_hidden_size = self.attn_pool.out_features
+                elif global_pool == 'attn_rot':
+                    from timm.layers import RotAttentionPool2d
+                    self.attn_pool = RotAttentionPool2d(
+                        self.num_features,
+                        out_features=cfg.head_hidden_size,
+                        ref_feat_size=feat_size,
+                    )
+                    global_pool = ''  # clear for ClassifierHead
+                    self.head_hidden_size = self.attn_pool.out_features
+                else:
+                    self.attn_pool = nn.Identity()
 
-            self.head = ClassifierHead(
-                self.head_hidden_size,
-                num_classes,
-                pool_type='' if cfg.attn_pool else global_pool,
-                drop_rate=self.drop_rate,
-            )
+                self.head = ClassifierHead(
+                    self.head_hidden_size,
+                    num_classes,
+                    pool_type=global_pool,
+                    drop_rate=self.drop_rate,
+                )
+            else:
+                if global_pool == 'attn_abs':
+                    from timm.layers import AttentionPool2d
+                    self.head = AttentionPool2d(
+                        self.num_features,
+                        out_features=num_classes,
+                        feat_size=feat_size,
+                        qkv_separate=True,
+                    )
+                    self.head_hidden_size = self.head.embed_dim
+                elif global_pool == 'attn_rot':
+                    from timm.layers import RotAttentionPool2d
+                    self.head = RotAttentionPool2d(
+                        self.num_features,
+                        out_features=num_classes,
+                        ref_feat_size=feat_size,
+                    )
+                    self.head_hidden_size = self.head.embed_dim
+                else:
+                    self.head = ClassifierHead(
+                        self.num_features,
+                        num_classes,
+                        pool_type=global_pool,
+                        drop_rate=self.drop_rate,
+                    )
+
         # init weights
         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@@ -1324,6 +1352,9 @@ class ByobNet(nn.Module):
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
+        if global_pool is not None:
+            if self.global_pool in ('attn_abs', 'attn_rot'):
+                raise RuntimeError('Cannot change attention pool on head reset.')
         self.head.reset(num_classes, global_pool)
 
     def forward_intermediates(
@@ -1413,7 +1444,7 @@ class ByobNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        x = self.attn_pool(x)
+        #x = self.attn_pool(x)
         return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
@@ -1916,7 +1947,6 @@ model_cfgs = dict(
         stem_pool='avg2',
         downsample='avg',
         aa_layer='avg',
-        attn_pool='abs',
         head_hidden_size=1024,
     ),
 
@@ -1932,7 +1962,6 @@ model_cfgs = dict(
         stem_pool='avg2',
         downsample='avg',
         aa_layer='avg',
-        attn_pool='abs',
         head_hidden_size=512,
     ),
 
@@ -1949,7 +1978,6 @@ model_cfgs = dict(
         stem_pool='avg2',
         downsample='avg',
         aa_layer='avg',
-        attn_pool='abs',
         head_hidden_size=640,
     ),
 
@@ -1960,12 +1988,12 @@ model_cfgs = dict(
             ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
             ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
         ),
+        width_factor=1.5,
         stem_chs=(32, 32, 64),
         stem_type='',
         stem_pool='avg2',
         downsample='avg',
         aa_layer='avg',
-        attn_pool='abs',
         head_hidden_size=768,
     ),
 
@@ -1976,12 +2004,12 @@ model_cfgs = dict(
             ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
             ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
         ),
+        width_factor=2.0,
         stem_chs=(32, 32, 64),
         stem_type='',
         stem_pool='avg2',
         downsample='avg',
         aa_layer='avg',
-        attn_pool='abs',
         head_hidden_size=1024,
     ),
 
@@ -2029,10 +2057,10 @@ def _convert_openai_clip(
             continue
         k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
         k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
-        k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.([a-z]+)([0-9])', _stage_sub, k)
-        k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.downsample\.([0-9])', _down_sub, k)
+        k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k)
+        k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k)
         if k.startswith(f'{prefix}attnpool'):
-            k = k.replace(prefix + 'attnpool', 'attn_pool')
+            k = k.replace(prefix + 'attnpool', 'head')  #'attn_pool')
             k = k.replace('positional_embedding', 'pos_embed')
             k = k.replace('q_proj', 'q')
             k = k.replace('k_proj', 'k')
@@ -2053,13 +2081,19 @@ def checkpoint_filter_fn(
 
 
 def _create_byobnet(variant, pretrained=False, **kwargs):
+    strict = True
+    if 'clip' in variant and kwargs.get('global_pool', None) != 'attn_abs':
+        # NOTE: a hack to allow removing attention pool from CLIP ResNet variants
+        strict = False
+
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
-        #pretrained_strict=False,
-        **kwargs)
+        pretrained_strict=strict,
+        **kwargs,
+    )
 
 
 def _cfg(url='', **kwargs):
@@ -2257,31 +2291,36 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
+        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
+        classifier='head.proj',
     ),
     'resnet101_clip.openai': _cfgr(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
+        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
+        classifier='head.proj',
     ),
     'resnet50x4_clip.openai': _cfgr(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9)
+        fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
+        classifier='head.proj',
     ),
     'resnet50x16_clip.openai': _cfgr(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12)
+        fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
+        classifier='head.proj',
     ),
     'resnet50x64_clip.openai': _cfgr(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14)
+        fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
+        classifier='head.proj',
     ),
 
 })
@@ -2592,35 +2631,40 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
 def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
     """ OpenAI Modified ResNet-50 CLIP image tower
     """
-    return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
+    model_args = dict(global_pool='attn_abs')
+    return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
     """ OpenAI Modified ResNet-101 CLIP image tower
     """
-    return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
+    model_args = dict(global_pool='attn_abs')
+    return _create_byobnet('resnet101_clip', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
     """ OpenAI Modified ResNet-50x4 CLIP image tower
     """
-    return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
+    model_args = dict(global_pool='attn_abs')
+    return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
     """ OpenAI Modified ResNet-50x16 CLIP image tower
    """
-    return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
+    model_args = dict(global_pool='attn_abs')
+    return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
     """ OpenAI Modified ResNet-50x64 CLIP image tower
     """
-    return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
+    model_args = dict(global_pool='attn_abs')
+    return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
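
Usage sketch (not part of the patch, assumes only the registrations above and the standard timm factory API): with this change the CLIP ResNet variants default to global_pool='attn_abs', so AttentionPool2d becomes the head, head.proj is treated as the classifier, and forward_head(..., pre_logits=True) returns the pooled embedding before that projection.

    import torch
    import timm

    # 'attn_abs' is the default global_pool for resnet50_clip per the register_model
    # wrappers above, so the head is AttentionPool2d and 'head.proj' is the classifier.
    model = timm.create_model('resnet50_clip', pretrained=False).eval()

    x = torch.randn(2, 3, 224, 224)
    feats = model.forward_features(x)                   # final-stage NCHW feature map
    embed = model.forward_head(feats, pre_logits=True)  # pooled embedding, before head.proj
    logits = model(x)                                   # full forward through head.proj

    # Overriding global_pool replaces the attention pool with a ClassifierHead; per the
    # _create_byobnet change above, 'clip' variants then load weights with pretrained_strict=False.
    classifier = timm.create_model('resnet50_clip', pretrained=False, global_pool='avg', num_classes=10)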