From 5efa15b2a2476a02c6f0ead515494e7b76cd13cd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 9 Jun 2024 16:54:48 -0700
Subject: [PATCH] Mapping OpenAI CLIP Modified ResNet weights -> ByobNet.
 Improve AttentionPool2d layers. Fix #1731

---
 timm/layers/attention_pool.py   |   6 +-
 timm/layers/attention_pool2d.py | 182 ++++++++++++-----
 timm/layers/classifier.py       |   2 -
 timm/layers/pos_embed_sincos.py |   1 -
 timm/models/byobnet.py          | 351 +++++++++++++++++++++++++++++---
 5 files changed, 455 insertions(+), 87 deletions(-)

diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py
index 41e404d2..da5585b3 100644
--- a/timm/layers/attention_pool.py
+++ b/timm/layers/attention_pool.py
@@ -20,6 +20,7 @@ class AttentionPoolLatent(nn.Module):
             out_features: int = None,
             embed_dim: int = None,
             num_heads: int = 8,
+            feat_size: Optional[int] = None,
             mlp_ratio: float = 4.0,
             qkv_bias: bool = True,
             qk_norm: bool = False,
@@ -36,13 +37,14 @@ class AttentionPoolLatent(nn.Module):
         assert embed_dim % num_heads == 0
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
+        self.feat_size = feat_size
         self.scale = self.head_dim ** -0.5
         self.pool = pool_type
         self.fused_attn = use_fused_attn()
 
         if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+            assert feat_size is not None
+            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
         else:
             self.pos_embed = None
 
diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py
index 765efa08..dc594b70 100644
--- a/timm/layers/attention_pool2d.py
+++ b/timm/layers/attention_pool2d.py
@@ -7,12 +7,14 @@ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/cli
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple
 
 import torch
 import torch.nn as nn
 
+from .config import use_fused_attn
 from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
 
@@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module):
     NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from train
     varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
+        self.fused_attn = use_fused_attn()
 
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
 
     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
-
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
 
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
 
-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
-
-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
 
@@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module):
     NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        embed_dim = embed_dim or in_features
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
 
     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+        if self.seq_len != N:
+            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        else:
+            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        x = x + pos_embed
 
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
 
diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index 27ee5e70..2441c050 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -24,8 +24,6 @@ def _create_pool(
 ):
     flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
     if not pool_type:
-        assert num_classes == 0 or use_conv,\
-            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
         flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
     global_pool = SelectAdaptivePool2d(
         pool_type=pool_type,
diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py
index b5f8502f..5bb31af5 100644
--- a/timm/layers/pos_embed_sincos.py
+++ b/timm/layers/pos_embed_sincos.py
@@ -312,7 +312,6 @@ class RotaryEmbedding(nn.Module):
                 temperature=temperature,
                 step=1,
             )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index 02e25836..b9417dfe 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -36,8 +36,8 @@ from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
 
 import torch
 import torch.nn as nn
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -70,15 +70,23 @@ class ByoModelCfg:
     downsample: str = 'conv1x1'
     stem_type: str = '3x3'
     stem_pool: Optional[str] = 'maxpool'
-    stem_chs: int = 32
+    stem_chs: Union[int, List[int], Tuple[int, ...]] = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
     zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
     fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation
 
+    # layer config
     act_layer: str = 'relu'
     norm_layer: str = 'batchnorm'
+    aa_layer: str = ''
 
+    # Head config
+    attn_pool: str = ''
+    head_hidden_size: Optional[int] = None  # feat dim of MLP head or AttentionPool output
+    head_type: str = ''
+
+    # Block config
     # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
     attn_layer: Optional[str] = None
     attn_kwargs: dict = field(default_factory=lambda: dict())
@@ -296,10 +304,7 @@ class BottleneckBlock(nn.Module):
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
         groups = num_groups(group_size, mid_chs)
 
-        self.shortcut = create_shortcut(
-            downsample, in_chs, out_chs,
-            stride=stride, dilation=dilation, apply_act=False, layers=layers,
-        )
-
         self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
         self.conv2_kxk = layers.conv_norm_act(
@@ -316,7 +321,10 @@
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+        self.shortcut = create_shortcut(
+            downsample, in_chs, out_chs,
+            stride=stride, dilation=dilation, apply_act=False, layers=layers,
+        )
 
     def init_weights(self, zero_init_last: bool = False):
         if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
@@ -917,7 +925,7 @@ class Stem(nn.Sequential):
     def __init__(
             self,
             in_chs: int,
-            out_chs: int,
+            out_chs: Union[int, List[int], Tuple[int, ...]],
             kernel_size: int = 3,
             stride: int = 4,
             pool: str = 'maxpool',
@@ -961,10 +969,19 @@ class Stem(nn.Sequential):
             curr_stride *= s
             prev_feat = conv_name
 
-        if pool and 'max' in pool.lower():
+        if pool:
+            pool = pool.lower()
+            assert pool in ('max', 'maxpool', 'avg', 'avgpool', 'max2', 'avg2')
             last_feat_idx = i
             self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
-            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
+            if pool == 'max2':
+                self.add_module('pool', nn.MaxPool2d(2))
+            elif pool == 'avg2':
+                self.add_module('pool', nn.AvgPool2d(2))
+            elif 'max' in pool:
+                self.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+            elif 'avg' in pool:
+                self.add_module('pool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False))
             curr_stride *= 2
             prev_feat = 'pool'
@@ -1012,11 +1029,14 @@ def create_byob_stem(
         else:
             stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
     else:
-        # 3x3 stem conv as in RegNet is the default
-        if pool_type:
-            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+        if isinstance(out_chs, (tuple, list)):
+            stem = Stem(in_chs, out_chs, 3, pool=pool_type, layers=layers)
         else:
-            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
+            # 3x3 stem conv as in RegNet is the default
+            if pool_type:
+                stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+            else:
+                stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
 
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
@@ -1138,13 +1158,16 @@ def create_byob_stages(
         prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
         feature_info.append(prev_feat)
 
-    return nn.Sequential(*stages), feature_info
+    return nn.Sequential(*stages), feature_info, feat_size
 
 
-def get_layer_fns(cfg: ByoModelCfg):
+def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True):
     act = get_act_layer(cfg.act_layer)
     norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
-    conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
+    if cfg.aa_layer and allow_aa:
+        conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
+    else:
+        conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
     attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
     self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
     layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
@@ -1191,23 +1214,33 @@ class ByobNet(nn.Module):
         self.grad_checkpointing = False
 
         cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
-        layers = get_layer_fns(cfg)
+        stem_layers = get_layer_fns(cfg, allow_aa=False)  # keep aa off for stem layers
+        stage_layers = get_layer_fns(cfg)
         if cfg.fixed_input_size:
             assert img_size is not None, 'img_size argument is required for fixed input size model'
         feat_size = to_2tuple(img_size) if img_size is not None else None
 
         self.feature_info = []
-        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
-        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+        if isinstance(cfg.stem_chs, (list, tuple)):
+            stem_chs = [int(round(c * cfg.width_factor)) for c in cfg.stem_chs]
+        else:
+            stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
+        self.stem, stem_feat = create_byob_stem(
+            in_chs=in_chans,
+            out_chs=stem_chs,
+            stem_type=cfg.stem_type,
+            pool_type=cfg.stem_pool,
+            layers=stem_layers,
+        )
         self.feature_info.extend(stem_feat[:-1])
         feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
 
-        self.stages, stage_feat = create_byob_stages(
+        self.stages, stage_feat, feat_size = create_byob_stages(
             cfg,
             drop_path_rate,
             output_stride,
             stem_feat[-1],
-            layers=layers,
+            layers=stage_layers,
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
@@ -1216,7 +1249,7 @@
         prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
             self.num_features = int(round(cfg.width_factor * cfg.num_features))
-            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
+            self.final_conv = stage_layers.conv_norm_act(prev_chs, self.num_features, 1)
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
@@ -1225,12 +1258,47 @@
         self.stage_ends = [f['stage'] for f in self.feature_info]
 
         self.head_hidden_size = self.num_features
-        self.head = ClassifierHead(
-            self.num_features,
-            num_classes,
-            pool_type=global_pool,
-            drop_rate=self.drop_rate,
-        )
+        assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
+        if cfg.head_type == 'norm_mlp_classifier':
+            from timm.layers import NormMlpClassifierHead
+            assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head"
+            self.attn_pool = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                hidden_size=cfg.head_hidden_size,
+                norm_layer=cfg.norm_layer,
+                act_layer=cfg.act_layer,
+            )
+            self.head_hidden_size = self.head.hidden_size
+        else:
+            if cfg.attn_pool == 'abs':
+                from timm.layers import AttentionPool2d
+                self.attn_pool = AttentionPool2d(
+                    self.num_features,
+                    out_features=cfg.head_hidden_size,
+                    feat_size=feat_size,
+                    qkv_separate=True,
+                )
+                self.head_hidden_size = self.attn_pool.out_features
+            elif cfg.attn_pool == 'rot':
+                from timm.layers import RotAttentionPool2d
+                self.attn_pool = RotAttentionPool2d(
+                    self.num_features,
+                    out_features=cfg.head_hidden_size,
+                    ref_feat_size=feat_size,
+                )
+                self.head_hidden_size = self.attn_pool.out_features
+            else:
+                assert not cfg.attn_pool
+                self.attn_pool = nn.Identity()
+
+            self.head = ClassifierHead(
+                self.head_hidden_size,
+                num_classes,
+                pool_type='' if cfg.attn_pool else global_pool,
+                drop_rate=self.drop_rate,
+            )
 
         # init weights
         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@@ -1345,6 +1413,7 @@ class ByobNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
+        x = self.attn_pool(x)
         return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
@@ -1834,14 +1903,162 @@ model_cfgs = dict(
         stem_type='one',
         stem_chs=64,
     ),
+
+    resnet50_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=1024,
+    ),
+
+    resnet101_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=23, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=512,
+    ),
+
+    resnet50x4_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=4, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=10, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=2048, s=2, br=0.25),
+        ),
+        width_factor=1.25,
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=640,
+    ),
+
+    resnet50x16_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=6, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=8, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
+        ),
+        width_factor=1.5,
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=768,
+    ),
+
+    resnet50x64_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=15, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
+        ),
+        width_factor=2.0,
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=1024,
+    ),
+
+    resnet50_nmlp=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        head_hidden_size=1024,
+        head_type='norm_mlp_classifier',
+    ),
 )
 
 
+def _convert_openai_clip(
+        state_dict: Dict[str, torch.Tensor],
+        model: ByobNet,
+        prefix: str = 'visual.',
+) -> Dict[str, torch.Tensor]:
+    import re
+
+    def _stage_sub(m):
+        stage_idx = int(m.group(1)) - 1
+        layer_idx, layer_type, layer_id = int(m.group(2)), m.group(3), int(m.group(4))
+        prefix_str = f'stages.{stage_idx}.{layer_idx}.'
+        id_map = {1: 'conv1_1x1.', 2: 'conv2_kxk.', 3: 'conv3_1x1.'}
+        suffix_str = id_map[layer_id] + layer_type
+        return prefix_str + suffix_str
+
+    def _down_sub(m):
+        stage_idx = int(m.group(1)) - 1
+        layer_idx, layer_id = int(m.group(2)), int(m.group(3))
+        return f'stages.{stage_idx}.{layer_idx}.shortcut.' + ('conv.conv' if layer_id == 0 else 'conv.bn')
+
+    out_dict = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
+        k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
+        k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k)
+        k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k)
+        if k.startswith(f'{prefix}attnpool'):
+            k = k.replace(prefix + 'attnpool', 'attn_pool')
+            k = k.replace('positional_embedding', 'pos_embed')
+            k = k.replace('q_proj', 'q')
+            k = k.replace('k_proj', 'k')
+            k = k.replace('v_proj', 'v')
+            k = k.replace('c_proj', 'proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: ByobNet
+):
+    if 'visual.conv1.weight' in state_dict:
+        state_dict = _convert_openai_clip(state_dict, model)
+    return state_dict
+
+
 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True),
+        #pretrained_strict=False,
         **kwargs)
@@ -2035,6 +2252,38 @@ default_cfgs = generate_default_cfgs({
         crop_pct=0.9,
         first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
     ),
+
+    'resnet50_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
+    ),
+    'resnet101_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
+    ),
+    'resnet50x4_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9)
+    ),
+    'resnet50x16_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12)
+    ),
+    'resnet50x64_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14)
+    ),
+
 })
@@ -2337,3 +2586,45 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
     """ """
     return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50 CLIP image tower
+    """
+    return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-101 CLIP image tower
+    """
+    return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50x4 CLIP image tower
+    """
+    return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50x16 CLIP image tower
+    """
+    return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50x64 CLIP image tower
+    """
+    return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet:
+    """ ResNet-50 w/ CLIP-style avg-pool stem/downsample and norm + MLP classifier head
+    """
+    return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs)
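
A minimal smoke test of the new registrations, assuming a build of this branch. The default cfgs above set num_classes=0 (pure embedding head) and, via fixed_input_size, a 224x224 default for resnet50_clip; pass pretrained=True only where the 'timm/' hub artifacts referenced above exist:

import torch
import timm

model = timm.create_model('resnet50_clip')  # pretrained=True pulls the remapped OpenAI weights
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    feats = model.forward_features(x)  # unpooled feature map from the last stage
    embed = model.forward_head(feats)  # attn_pool -> ClassifierHead (pooling disabled)
print(feats.shape)  # torch.Size([2, 2048, 7, 7])
print(embed.shape)  # torch.Size([2, 1024]), i.e. head_hidden_size / CLIP embed dim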
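
The resample_abs_pos_embed call added to AttentionPool2d.forward is what lets the 'abs' pool run at grid sizes other than the construction-time feat_size. A standalone sketch on this branch (32 heads matches the CLIP RN50 attnpool; sizes are illustrative, embed_dim defaults to in_features):

import torch
from timm.layers import AttentionPool2d

pool = AttentionPool2d(2048, feat_size=7, num_heads=32, qkv_separate=True)
with torch.no_grad():
    y = pool(torch.randn(2, 2048, 7, 7))    # seq_len == H*W: pos_embed added as-is
    z = pool(torch.randn(2, 2048, 10, 10))  # seq_len != H*W: pos_embed resampled to the 10x10 grid
print(y.shape, z.shape)  # torch.Size([2, 2048]) torch.Size([2, 2048]); out_features defaults to in_features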
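
And the flavor of the _convert_openai_clip key remapping, on a hypothetical checkpoint key (same pattern and substitution as above): the 1-based OpenAI stage index becomes 0-based, and convN/bnN pairs collapse into the ByobNet conv_norm_act submodules.

import re

def _stage_sub(m):
    # visual.layer<stage>.<block>.<type><id> -> stages.<stage-1>.<block>.<mapped id>.<type>
    stage_idx = int(m.group(1)) - 1
    layer_idx, layer_type, layer_id = int(m.group(2)), m.group(3), int(m.group(4))
    id_map = {1: 'conv1_1x1.', 2: 'conv2_kxk.', 3: 'conv3_1x1.'}
    return f'stages.{stage_idx}.{layer_idx}.' + id_map[layer_id] + layer_type

k = re.sub(r'visual\.layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, 'visual.layer3.5.bn2.weight')
print(k)  # stages.2.5.conv2_kxk.bn.weight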