Mapping OpenAI CLIP Modified ResNet weights -> ByobNet. Improve AttentionPool2d layers. Fix #1731

Ross Wightman 2024-06-09 16:54:48 -07:00
parent 7702d9afa1
commit 5efa15b2a2
5 changed files with 455 additions and 87 deletions

timm/layers/attention_pool.py

@@ -20,6 +20,7 @@ class AttentionPoolLatent(nn.Module):
out_features: int = None,
embed_dim: int = None,
num_heads: int = 8,
feat_size: Optional[int] = None,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
@@ -36,13 +37,14 @@ class AttentionPoolLatent(nn.Module):
assert embed_dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.feat_size = feat_size
self.scale = self.head_dim ** -0.5
self.pool = pool_type
self.fused_attn = use_fused_attn()
if pos_embed == 'abs':
spatial_len = self.feat_size
self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
assert feat_size is not None
self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
else:
self.pos_embed = None
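With this change the abs position embedding path takes feat_size explicitly and asserts it is provided. A minimal usage sketch (import path and constructor defaults assumed from timm):

import torch
from timm.layers import AttentionPoolLatent  # assumed export path

pool = AttentionPoolLatent(
    in_features=768,
    num_heads=12,
    feat_size=196,    # 14x14 token grid; now required when pos_embed='abs'
    pos_embed='abs',
)
tokens = torch.randn(2, 196, 768)  # (batch, seq, dim)
pooled = pool(tokens)              # (2, 768) with the default 'token' pooling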

timm/layers/attention_pool2d.py

@@ -7,12 +7,14 @@ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/cli
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Union, Tuple
from typing import Optional, Union, Tuple
import torch
import torch.nn as nn
from .config import use_fused_attn
from .helpers import to_2tuple
from .pos_embed import resample_abs_pos_embed
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_
@@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module):
NOTE: While this impl does not require a fixed feature size, performance at resolutions differing from
train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
in_features: int,
out_features: int = None,
embed_dim: int = None,
num_heads: int = 4,
out_features: Optional[int] = None,
ref_feat_size: Union[int, Tuple[int, int]] = 7,
embed_dim: Optional[int] = None,
head_dim: Optional[int] = 64,
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
):
super().__init__()
embed_dim = embed_dim or in_features
out_features = out_features or in_features
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, out_features)
self.in_features = in_features
self.out_features = out_features or in_features
ref_feat_size = to_2tuple(ref_feat_size)
if num_heads is not None:
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
else:
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.pos_embed = RotaryEmbedding(self.head_dim)
self.fused_attn = use_fused_attn()
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.qkv = None
else:
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
def init_weights(self, zero_init_last: bool = False):
if self.qkv is None:
in_features = self.q.in_features
trunc_normal_(self.q.weight, std=in_features ** -0.5)
nn.init.zeros_(self.q.bias)
trunc_normal_(self.k.weight, std=in_features ** -0.5)
nn.init.zeros_(self.k.bias)
trunc_normal_(self.v.weight, std=in_features ** -0.5)
nn.init.zeros_(self.v.bias)
else:
in_features = self.qkv.in_features
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
def forward(self, x):
B, _, H, W = x.shape
N = H * W
x = x.reshape(B, -1, N).permute(0, 2, 1)
x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
else:
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x.unbind(0)
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]
rse, rce = self.pos_embed.get_embed((H, W))
q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
qc, q = q[:, :, :1], q[:, :, 1:]
sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
q = apply_rot_embed(q, sin_emb, cos_emb)
q = torch.cat([qc, q], dim=2)
kc, k = k[:, :, :1], k[:, :, 1:]
k = apply_rot_embed(k, sin_emb, cos_emb)
k = torch.cat([kc, k], dim=2)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = self.proj(x)
return x[:, 0]
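A usage sketch for the reworked rotary pooling head (export path assumed; per the NOTE above, ref_feat_size should match the train-time grid for best results):

import torch
from timm.layers import RotAttentionPool2d  # assumed export path

pool = RotAttentionPool2d(in_features=2048, out_features=1024, ref_feat_size=7)
feat = torch.randn(2, 2048, 7, 7)  # NCHW feature map from a CNN trunk
emb = pool(feat)                   # (2, 1024): projected output of the mean-token query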
@@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module):
NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
"""
fused_attn: torch.jit.Final[bool]
def __init__(
self,
in_features: int,
feat_size: Union[int, Tuple[int, int]],
out_features: int = None,
embed_dim: int = None,
num_heads: int = 4,
feat_size: Union[int, Tuple[int, int]] = 7,
out_features: Optional[int] = None,
embed_dim: Optional[int] = None,
head_dim: Optional[int] = 64,
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
):
super().__init__()
embed_dim = embed_dim or in_features
out_features = out_features or in_features
assert embed_dim % num_heads == 0
self.in_features = in_features
self.out_features = out_features or in_features
if num_heads is not None:
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
else:
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
self.feat_size = to_2tuple(feat_size)
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, out_features)
self.seq_len = self.feat_size[0] * self.feat_size[1]
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
spatial_dim = self.feat_size[0] * self.feat_size[1]
self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.qkv = None
else:
self.q = self.k = self.v = None
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
self.init_weights()
def init_weights(self, zero_init_last: bool = False):
if self.qkv is None:
in_features = self.q.in_features
trunc_normal_(self.q.weight, std=in_features ** -0.5)
nn.init.zeros_(self.q.bias)
trunc_normal_(self.k.weight, std=in_features ** -0.5)
nn.init.zeros_(self.k.bias)
trunc_normal_(self.v.weight, std=in_features ** -0.5)
nn.init.zeros_(self.v.bias)
else:
in_features = self.qkv.in_features
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
def forward(self, x):
B, _, H, W = x.shape
N = H * W
assert self.feat_size[0] == H
assert self.feat_size[1] == W
x = x.reshape(B, -1, N).permute(0, 2, 1)
x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
if self.seq_len != N:
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
else:
pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
x = x + pos_embed
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
else:
x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x.unbind(0)
x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = self.proj(x)
return x[:, 0]
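The practical effect of the resample_abs_pos_embed branch above: the hard feat_size asserts are gone, so off-size inputs now work by interpolating the learned position embedding. A sketch (export path assumed):

import torch
from timm.layers import AttentionPool2d  # assumed export path

pool = AttentionPool2d(in_features=2048, out_features=1024, feat_size=7)
emb = pool(torch.randn(2, 2048, 7, 7))    # native 7x7 grid -> (2, 1024)
emb = pool(torch.randn(2, 2048, 12, 12))  # off-size grid: pos_embed is resampled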

timm/layers/classifier.py

@@ -24,8 +24,6 @@ def _create_pool(
):
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(
pool_type=pool_type,

timm/layers/pos_embed_sincos.py

@@ -312,7 +312,6 @@ class RotaryEmbedding(nn.Module):
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,

timm/models/byobnet.py

@@ -36,8 +36,8 @@ from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
@@ -70,15 +70,23 @@ class ByoModelCfg:
downsample: str = 'conv1x1'
stem_type: str = '3x3'
stem_pool: Optional[str] = 'maxpool'
stem_chs: int = 32
stem_chs: Union[int, List[int], Tuple[int, ...]] = 32
width_factor: float = 1.0
num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
zero_init_last: bool = True # zero init last weight (usually bn) in residual path
fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
# layer config
act_layer: str = 'relu'
norm_layer: str = 'batchnorm'
aa_layer: str = ''
# Head config
attn_pool: str = ''
head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output
head_type: str = ''
# Block config
# NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
attn_layer: Optional[str] = None
attn_kwargs: dict = field(default_factory=lambda: dict())
@@ -296,10 +304,7 @@ class BottleneckBlock(nn.Module):
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
groups = num_groups(group_size, mid_chs)
self.shortcut = create_shortcut(
downsample, in_chs, out_chs,
stride=stride, dilation=dilation, apply_act=False, layers=layers,
)
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.conv2_kxk = layers.conv_norm_act(
@@ -316,7 +321,10 @@ class BottleneckBlock(nn.Module):
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
self.shortcut = create_shortcut(
downsample, in_chs, out_chs,
stride=stride, dilation=dilation, apply_act=False, layers=layers,
)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
@@ -917,7 +925,7 @@ class Stem(nn.Sequential):
def __init__(
self,
in_chs: int,
out_chs: int,
out_chs: Union[int, List[int], Tuple[int, ...]],
kernel_size: int = 3,
stride: int = 4,
pool: str = 'maxpool',
@@ -961,10 +969,19 @@ class Stem(nn.Sequential):
curr_stride *= s
prev_feat = conv_name
if pool and 'max' in pool.lower():
if pool:
pool = pool.lower()
assert pool in ('max', 'maxpool', 'avg', 'avgpool', 'max2', 'avg2')
last_feat_idx = i
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
self.add_module('pool', nn.MaxPool2d(3, 2, 1))
if pool == 'max2':
self.add_module('pool', nn.MaxPool2d(2))
elif pool == 'avg2':
self.add_module('pool', nn.AvgPool2d(2))
elif 'max' in pool:
self.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
elif 'avg' in pool:
self.add_module('pool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False))
curr_stride *= 2
prev_feat = 'pool'
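For reference, the pool specs accepted above map onto plain pooling modules; 'avg2' reproduces the 2x2 average pool used in the OpenAI CLIP ModifiedResNet stem:

import torch.nn as nn

nn.AvgPool2d(2)                                                 # pool='avg2' (CLIP-style stem)
nn.MaxPool2d(2)                                                 # pool='max2'
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)                # pool='max' / 'maxpool'
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)  # pool='avg' / 'avgpool'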
@@ -1012,11 +1029,14 @@ def create_byob_stem(
else:
stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
else:
# 3x3 stem conv as in RegNet is the default
if pool_type:
stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
if isinstance(out_chs, (tuple, list)):
stem = Stem(in_chs, out_chs, 3, pool=pool_type, layers=layers)
else:
stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
# 3x3 stem conv as in RegNet is the default
if pool_type:
stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
else:
stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
if isinstance(stem, Stem):
feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
@@ -1138,13 +1158,16 @@ def create_byob_stages(
prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
feature_info.append(prev_feat)
return nn.Sequential(*stages), feature_info
return nn.Sequential(*stages), feature_info, feat_size
def get_layer_fns(cfg: ByoModelCfg):
def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True):
act = get_act_layer(cfg.act_layer)
norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
if cfg.aa_layer and allow_aa:
conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
else:
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
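As wired here, a non-empty cfg.aa_layer swaps the conv helper to ConvNormActAa, which runs the conv at stride 1 and applies the stride through an anti-aliasing pool; the CLIP configs below pass aa_layer='avg' to match the ModifiedResNet's avg-pool downsampling. A standalone sketch (string spec passed through as in the diff, assuming it is resolved internally):

from timm.layers import ConvNormActAa

# stride-2 step realized as a stride-1 3x3 conv followed by a stride-2 avg pool
layer = ConvNormActAa(64, 128, kernel_size=3, stride=2, aa_layer='avg')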
@@ -1191,23 +1214,33 @@ class ByobNet(nn.Module):
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layers = get_layer_fns(cfg)
stem_layers = get_layer_fns(cfg, allow_aa=False) # keep aa off for stem-layers
stage_layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
feat_size = to_2tuple(img_size) if img_size is not None else None
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
if isinstance(cfg.stem_chs, (list, tuple)):
stem_chs = [int(round(c * cfg.width_factor)) for c in cfg.stem_chs]
else:
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(
in_chs=in_chans,
out_chs=stem_chs,
stem_type=cfg.stem_type,
pool_type=cfg.stem_pool,
layers=stem_layers,
)
self.feature_info.extend(stem_feat[:-1])
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
self.stages, stage_feat = create_byob_stages(
self.stages, stage_feat, feat_size = create_byob_stages(
cfg,
drop_path_rate,
output_stride,
stem_feat[-1],
layers=layers,
layers=stage_layers,
feat_size=feat_size,
)
self.feature_info.extend(stage_feat[:-1])
@@ -1216,7 +1249,7 @@ class ByobNet(nn.Module):
prev_chs = stage_feat[-1]['num_chs']
if cfg.num_features:
self.num_features = int(round(cfg.width_factor * cfg.num_features))
self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
self.final_conv = stage_layers.conv_norm_act(prev_chs, self.num_features, 1)
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
@@ -1225,12 +1258,47 @@ class ByobNet(nn.Module):
self.stage_ends = [f['stage'] for f in self.feature_info]
self.head_hidden_size = self.num_features
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
if cfg.head_type == 'norm_mlp_classifier':
from timm.layers import NormMlpClassifierHead
assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head"
self.attn_pool = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=cfg.head_hidden_size,
norm_layer=cfg.norm_layer,
act_layer=cfg.act_layer,
)
self.head_hidden_size = self.head.hidden_size
else:
if cfg.attn_pool == 'abs':
from timm.layers import AttentionPool2d
self.attn_pool = AttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
feat_size=feat_size,
qkv_separate=True,
)
self.head_hidden_size = self.attn_pool.out_features
elif cfg.attn_pool == 'rot':
from timm.layers import RotAttentionPool2d
self.attn_pool = RotAttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
ref_feat_size=feat_size,
)
self.head_hidden_size = self.attn_pool.out_features
else:
assert not cfg.attn_pool
self.attn_pool = nn.Identity()
self.head = ClassifierHead(
self.head_hidden_size,
num_classes,
pool_type='' if cfg.attn_pool else global_pool,
drop_rate=self.drop_rate,
)
# init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@@ -1345,6 +1413,7 @@ class ByobNet(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.attn_pool(x)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
@@ -1834,14 +1903,162 @@ model_cfgs = dict(
stem_type='one',
stem_chs=64,
),
resnet50_clip=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=1024,
),
resnet101_clip=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=23, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=512,
),
resnet50x4_clip=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=4, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=10, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=2048, s=2, br=0.25),
),
width_factor=1.25,
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=640,
),
resnet50x16_clip=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=6, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=8, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
),
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=768,
),
resnet50x64_clip=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=15, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
),
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
attn_pool='abs',
head_hidden_size=1024,
),
resnet50_nmlp=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=(32, 32, 64),
stem_type='',
stem_pool='avg2',
downsample='avg',
aa_layer='avg',
head_hidden_size=1024,
head_type='norm_mlp_classifier',
),
)
def _convert_openai_clip(
state_dict: Dict[str, torch.Tensor],
model: ByobNet,
prefix: str = 'visual.',
) -> Dict[str, torch.Tensor]:
import re
def _stage_sub(m):
stage_idx = int(m.group(1)) - 1
layer_idx, layer_type, layer_id = int(m.group(2)), m.group(3), int(m.group(4))
prefix_str = f'stages.{stage_idx}.{layer_idx}.'
id_map = {1: 'conv1_1x1.', 2: 'conv2_kxk.', 3: 'conv3_1x1.'}
suffix_str = id_map[layer_id] + layer_type
return prefix_str + suffix_str
def _down_sub(m):
stage_idx = int(m.group(1)) - 1
layer_idx, layer_id = int(m.group(2)), int(m.group(3))
return f'stages.{stage_idx}.{layer_idx}.shortcut.' + ('conv.conv' if layer_id == 0 else 'conv.bn')
out_dict = {}
for k, v in state_dict.items():
if not k.startswith(prefix):
continue
k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
# NOTE: '+' on the block index matters for deeper CLIP stages (e.g. d=36 in resnet50x64)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k)
if k.startswith(f'{prefix}attnpool'):
k = k.replace(prefix + 'attnpool', 'attn_pool')
k = k.replace('positional_embedding', 'pos_embed')
k = k.replace('q_proj', 'q')
k = k.replace('k_proj', 'k')
k = k.replace('v_proj', 'v')
k = k.replace('c_proj', 'proj')
out_dict[k] = v
return out_dict
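Worked examples of the remapping (left-hand keys follow the OpenAI ModifiedResNet checkpoint layout):

visual.conv1.weight                  -> stem.conv1.conv.weight
visual.bn1.weight                    -> stem.conv1.bn.weight
visual.layer1.0.conv2.weight         -> stages.0.0.conv2_kxk.conv.weight
visual.layer1.0.bn3.running_mean     -> stages.0.0.conv3_1x1.bn.running_mean
visual.layer2.0.downsample.0.weight  -> stages.1.0.shortcut.conv.conv.weight
visual.layer2.0.downsample.1.weight  -> stages.1.0.shortcut.conv.bn.weight
visual.attnpool.positional_embedding -> attn_pool.pos_embed
visual.attnpool.q_proj.weight        -> attn_pool.q.weight
visual.attnpool.c_proj.bias          -> attn_pool.proj.bias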
def checkpoint_filter_fn(
state_dict: Dict[str, torch.Tensor],
model: ByobNet
):
if 'visual.conv1.weight' in state_dict:
state_dict = _convert_openai_clip(state_dict, model)
return state_dict
def _create_byobnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
#pretrained_strict=False,
**kwargs)
@@ -2035,6 +2252,38 @@ default_cfgs = generate_default_cfgs({
crop_pct=0.9,
first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
),
'resnet50_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
),
'resnet101_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
),
'resnet50x4_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9)
),
'resnet50x16_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12)
),
'resnet50x64_clip.openai': _cfgr(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14)
),
})
@@ -2337,3 +2586,45 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
"""
"""
return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
@register_model
def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50 CLIP image tower
"""
return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
@register_model
def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-101 CLIP image tower
"""
return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
@register_model
def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x4 CLIP image tower
"""
return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
@register_model
def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x16 CLIP image tower
"""
return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
@register_model
def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x64 CLIP image tower
"""
return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
@register_model
def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet:
"""
"""
return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs)
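End to end, the converted towers then load through the regular timm factory. A sketch, assuming the hub artifacts referenced in the default_cfgs above are available:

import timm
import torch

model = timm.create_model('resnet50_clip.openai', pretrained=True)  # image tower, num_classes=0
emb = model(torch.randn(1, 3, 224, 224))  # (1, 1024): attn_pool output (head_hidden_size)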