Make 2d attention pool modules compatible with head interface. Use attention pool in CLIP ResNets as head. Make separate set of GAP models w/ avg pool instead of attn pool.

This commit is contained in:
Ross Wightman 2024-06-11 21:32:07 -07:00
parent 30ffa152de
commit cdc7bcea69
2 changed files with 209 additions and 104 deletions

View File

@ -41,9 +41,12 @@ class RotAttentionPool2d(nn.Module):
num_heads: Optional[int] = None, num_heads: Optional[int] = None,
qkv_bias: bool = True, qkv_bias: bool = True,
qkv_separate: bool = False, qkv_separate: bool = False,
drop: float = 0., pool_type: str = 'token',
avg_token: bool = True,
drop_rate: float = 0.,
): ):
super().__init__() super().__init__()
assert pool_type in ('', 'token')
self.embed_dim = embed_dim = embed_dim or in_features self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features self.in_features = in_features
self.out_features = out_features or in_features self.out_features = out_features or in_features
@ -56,6 +59,7 @@ class RotAttentionPool2d(nn.Module):
num_heads = embed_dim // head_dim num_heads = embed_dim // head_dim
self.num_heads = num_heads self.num_heads = num_heads
self.head_dim = head_dim self.head_dim = head_dim
self.pool_type = pool_type.lower()
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn() self.fused_attn = use_fused_attn()
@ -66,6 +70,7 @@ class RotAttentionPool2d(nn.Module):
self.qkv = None self.qkv = None
else: else:
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(embed_dim, self.out_features) self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size) self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
@ -83,6 +88,23 @@ class RotAttentionPool2d(nn.Module):
trunc_normal_(self.qkv.weight, std=in_features ** -0.5) trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias) nn.init.zeros_(self.qkv.bias)
def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
# NOTE: this module is being used as a head, so need compatible reset()
if pool_type is not None:
assert pool_type in ('', 'token')
self.pool_type = pool_type
if num_classes is not None:
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
self.out_features = num_classes if num_classes > 0 else self.embed_dim
def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
if self.pool_type == 'token':
x = x[:, 0]
else:
# if not pooled, return spatial output without token
x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
return x
def forward(self, x, pre_logits: bool = False): def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape B, _, H, W = x.shape
N = H * W N = H * W
@ -111,8 +133,10 @@ class RotAttentionPool2d(nn.Module):
x = x[:, 0] x = x[:, 0]
x = self.drop(x) x = self.drop(x)
if pre_logits: if pre_logits:
x = self._pool(x, H, W)
return x return x
x = self.proj(x) x = self.proj(x)
x = self._pool(x, H, W)
return x return x
@ -137,9 +161,12 @@ class AttentionPool2d(nn.Module):
num_heads: Optional[int] = None, num_heads: Optional[int] = None,
qkv_bias: bool = True, qkv_bias: bool = True,
qkv_separate: bool = False, qkv_separate: bool = False,
drop: float = 0., pool_type: str = 'token',
learned_token: bool = False,
drop_rate: float = 0.,
): ):
super().__init__() super().__init__()
assert pool_type in ('', 'token')
self.embed_dim = embed_dim = embed_dim or in_features self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features self.in_features = in_features
self.out_features = out_features or in_features self.out_features = out_features or in_features
@ -153,9 +180,15 @@ class AttentionPool2d(nn.Module):
self.seq_len = self.feat_size[0] * self.feat_size[1] self.seq_len = self.feat_size[0] * self.feat_size[1]
self.num_heads = num_heads self.num_heads = num_heads
self.head_dim = head_dim self.head_dim = head_dim
self.pool_type = pool_type
self.scale = self.head_dim ** -0.5 self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn() self.fused_attn = use_fused_attn()
if learned_token:
self.token = nn.Parameter(torch.zeros(1, embed_dim))
else:
self.token = None
if qkv_separate: if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@ -164,7 +197,7 @@ class AttentionPool2d(nn.Module):
else: else:
self.q = self.k = self.v = None self.q = self.k = self.v = None
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop) self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(embed_dim, self.out_features) self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features)) self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
@ -185,11 +218,31 @@ class AttentionPool2d(nn.Module):
nn.init.zeros_(self.qkv.bias) nn.init.zeros_(self.qkv.bias)
trunc_normal_(self.pos_embed, std=in_features ** -0.5) trunc_normal_(self.pos_embed, std=in_features ** -0.5)
def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
# NOTE: this module is being used as a head, so need compatible reset()
if pool_type is not None:
assert pool_type in ('', 'token')
self.pool_type = pool_type
if num_classes is not None:
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
self.out_features = num_classes if num_classes > 0 else self.embed_dim
def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
if self.pool_type == 'token':
x = x[:, 0]
else:
# if not pooled, return spatial output without token
x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
return x
def forward(self, x, pre_logits: bool = False): def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape B, _, H, W = x.shape
N = H * W N = H * W
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1) if self.token is not None:
x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
else:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
x = x + pos_embed x = x + pos_embed
@ -209,9 +262,10 @@ class AttentionPool2d(nn.Module):
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
x = attn @ v x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1) x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = x[:, 0]
x = self.drop(x) x = self.drop(x)
if pre_logits: if pre_logits:
x = self._pool(x, H, W)
return x return x
x = self.proj(x) x = self.proj(x)
x = self._pool(x, H, W)
return x return x

View File

@ -37,8 +37,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ from timm.layers import (
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a,
AttentionPool2d, RotAttentionPool2d, DropPath, AvgPool2dSame,
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple,
)
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq from ._manipulate import named_apply, checkpoint_seq
@ -83,7 +86,7 @@ class ByoModelCfg:
# Head config # Head config
head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output
head_type: str = '' head_type: str = 'classifier'
# Block config # Block config
# NOTE: these config items will be overridden by the block cfg (per-block) if they are set there # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
@ -1186,7 +1189,7 @@ class ByobNet(nn.Module):
cfg: ByoModelCfg, cfg: ByoModelCfg,
num_classes: int = 1000, num_classes: int = 1000,
in_chans: int = 3, in_chans: int = 3,
global_pool: str = 'avg', global_pool: Optional[str] = None,
output_stride: int = 32, output_stride: int = 32,
img_size: Optional[Union[int, Tuple[int, int]]] = None, img_size: Optional[Union[int, Tuple[int, int]]] = None,
drop_rate: float = 0., drop_rate: float = 0.,
@ -1257,76 +1260,59 @@ class ByobNet(nn.Module):
self.stage_ends = [f['stage'] for f in self.feature_info] self.stage_ends = [f['stage'] for f in self.feature_info]
self.head_hidden_size = self.num_features self.head_hidden_size = self.num_features
self.global_pool = global_pool assert cfg.head_type in ('', 'classifier', 'mlp', 'attn_abs', 'attn_rot')
assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier') if cfg.head_type == 'mlp':
if cfg.head_type == 'norm_mlp_classifier': if global_pool is None:
from timm.layers import NormMlpClassifierHead global_pool = 'avg'
assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head"
self.attn_pool = nn.Identity()
self.head = NormMlpClassifierHead( self.head = NormMlpClassifierHead(
self.num_features, self.num_features,
num_classes, num_classes,
hidden_size=cfg.head_hidden_size, hidden_size=cfg.head_hidden_size,
pool_type=global_pool,
norm_layer=cfg.norm_layer, norm_layer=cfg.norm_layer,
act_layer=cfg.act_layer, act_layer=cfg.act_layer,
drop_rate=self.drop_rate,
) )
self.head_hidden_size = self.head.hidden_size self.head_hidden_size = self.head.hidden_size
elif cfg.head_type == 'attn_abs':
if global_pool is None:
global_pool = 'token'
assert global_pool in ('', 'token')
self.head = AttentionPool2d(
self.num_features,
embed_dim=cfg.head_hidden_size,
out_features=num_classes,
feat_size=feat_size,
pool_type=global_pool,
drop_rate=self.drop_rate,
qkv_separate=True,
)
self.head_hidden_size = self.head.embed_dim
elif cfg.head_type =='attn_rot':
if global_pool is None:
global_pool = 'token'
assert global_pool in ('', 'token')
self.head = RotAttentionPool2d(
self.num_features,
embed_dim=cfg.head_hidden_size,
out_features=num_classes,
ref_feat_size=feat_size,
pool_type=global_pool,
drop_rate=self.drop_rate,
qkv_separate=True,
)
self.head_hidden_size = self.head.embed_dim
else: else:
# FIXME evaluating different head vs pool configurations if global_pool is None:
if False: global_pool = 'avg'
if global_pool == 'attn_abs': assert cfg.head_hidden_size is None
from timm.layers import AttentionPool2d self.head = ClassifierHead(
self.attn_pool = AttentionPool2d( self.num_features,
self.num_features, num_classes,
out_features=cfg.head_hidden_size, pool_type=global_pool,
feat_size=feat_size, drop_rate=self.drop_rate,
qkv_separate=True, )
) self.global_pool = global_pool
global_pool = '' # clear for ClassifierHead
self.head_hidden_size = self.attn_pool.out_features
elif global_pool =='attn_rot':
from timm.layers import RotAttentionPool2d
self.attn_pool = RotAttentionPool2d(
self.num_features,
out_features=cfg.head_hidden_size,
ref_feat_size=feat_size,
)
global_pool = '' # clear for ClassifierHead
self.head_hidden_size = self.attn_pool.out_features
else:
self.attn_pool = nn.Identity()
self.head = ClassifierHead(
self.head_hidden_size,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
else:
if global_pool == 'attn_abs':
from timm.layers import AttentionPool2d
self.head = AttentionPool2d(
self.num_features,
out_features=num_classes,
feat_size=feat_size,
qkv_separate=True,
)
self.head_hidden_size = self.head.embed_dim
elif global_pool == 'attn_rot':
from timm.layers import RotAttentionPool2d
self.head = RotAttentionPool2d(
self.num_features,
out_features=num_classes,
ref_feat_size=feat_size,
)
self.head_hidden_size = self.head.embed_dim
else:
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
# init weights # init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@ -1352,9 +1338,6 @@ class ByobNet(nn.Module):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes self.num_classes = num_classes
if global_pool is not None:
if self.global_pool in ('attn_abs', 'attn_rot'):
raise RuntimeError('Cannot change attention pool on head reset.')
self.head.reset(num_classes, global_pool) self.head.reset(num_classes, global_pool)
def forward_intermediates( def forward_intermediates(
@ -1444,7 +1427,6 @@ class ByobNet(nn.Module):
return x return x
def forward_head(self, x, pre_logits: bool = False): def forward_head(self, x, pre_logits: bool = False):
#x = self.attn_pool(x)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x): def forward(self, x):
@ -1947,7 +1929,7 @@ model_cfgs = dict(
stem_pool='avg2', stem_pool='avg2',
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_hidden_size=1024, head_type='attn_abs',
), ),
resnet101_clip=ByoModelCfg( resnet101_clip=ByoModelCfg(
@ -1962,7 +1944,8 @@ model_cfgs = dict(
stem_pool='avg2', stem_pool='avg2',
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_hidden_size=512, head_type='attn_abs',
#head_hidden_size=512,
), ),
resnet50x4_clip=ByoModelCfg( resnet50x4_clip=ByoModelCfg(
@ -1978,7 +1961,8 @@ model_cfgs = dict(
stem_pool='avg2', stem_pool='avg2',
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_hidden_size=640, head_type='attn_abs',
#head_hidden_size=640,
), ),
resnet50x16_clip=ByoModelCfg( resnet50x16_clip=ByoModelCfg(
@ -1994,7 +1978,8 @@ model_cfgs = dict(
stem_pool='avg2', stem_pool='avg2',
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_hidden_size=768, head_type='attn_abs',
#head_hidden_size=768,
), ),
resnet50x64_clip=ByoModelCfg( resnet50x64_clip=ByoModelCfg(
@ -2010,10 +1995,11 @@ model_cfgs = dict(
stem_pool='avg2', stem_pool='avg2',
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_hidden_size=1024, head_type='attn_abs',
#head_hidden_size=1024,
), ),
resnet50_nmlp=ByoModelCfg( resnet50_mlp=ByoModelCfg(
blocks=( blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
@ -2026,9 +2012,11 @@ model_cfgs = dict(
downsample='avg', downsample='avg',
aa_layer='avg', aa_layer='avg',
head_hidden_size=1024, head_hidden_size=1024,
head_type='norm_mlp_classifier', head_type='mlp',
), ),
) )
for k in ('resnet50_clip', 'resnet101_clip', 'resnet50x4_clip', 'resnet50x16_clip', 'resnet50x64_clip'):
model_cfgs[k + '_gap'] = replace(model_cfgs[k], head_type='classifier')
def _convert_openai_clip( def _convert_openai_clip(
@ -2036,6 +2024,7 @@ def _convert_openai_clip(
model: ByobNet, model: ByobNet,
prefix: str = 'visual.', prefix: str = 'visual.',
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
model_has_attn_pool = isinstance(model.head, (RotAttentionPool2d, AttentionPool2d))
import re import re
def _stage_sub(m): def _stage_sub(m):
@ -2060,6 +2049,8 @@ def _convert_openai_clip(
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k) k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k)
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k) k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k)
if k.startswith(f'{prefix}attnpool'): if k.startswith(f'{prefix}attnpool'):
if not model_has_attn_pool:
continue
k = k.replace(prefix + 'attnpool', 'head') #'attn_pool') k = k.replace(prefix + 'attnpool', 'head') #'attn_pool')
k = k.replace('positional_embedding', 'pos_embed') k = k.replace('positional_embedding', 'pos_embed')
k = k.replace('q_proj', 'q') k = k.replace('q_proj', 'q')
@ -2081,17 +2072,11 @@ def checkpoint_filter_fn(
def _create_byobnet(variant, pretrained=False, **kwargs): def _create_byobnet(variant, pretrained=False, **kwargs):
strict = True
if 'clip' in variant and kwargs.get('global_pool', None) != 'attn_abs':
# NOTE: a hack to allow removing attention pool from CLIP ResNet variants
strict = False
return build_model_with_cfg( return build_model_with_cfg(
ByobNet, variant, pretrained, ByobNet, variant, pretrained,
model_cfg=model_cfgs[variant], model_cfg=model_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True), feature_cfg=dict(flatten_sequential=True),
pretrained_strict=strict,
**kwargs, **kwargs,
) )
@ -2287,42 +2272,78 @@ default_cfgs = generate_default_cfgs({
first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
), ),
# original attention pool head variants
'resnet50_clip.openai': _cfgr( 'resnet50_clip.openai': _cfgr(
hf_hub_id='timm/', hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier = 'head.proj', classifier = 'head.proj',
), ),
'resnet101_clip.openai': _cfgr( 'resnet101_clip.openai': _cfgr(
hf_hub_id='timm/', hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
classifier='head.proj', classifier='head.proj',
), ),
'resnet50x4_clip.openai': _cfgr( 'resnet50x4_clip.openai': _cfgr(
hf_hub_id='timm/', hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=640, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9), fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
classifier = 'head.proj', classifier = 'head.proj',
), ),
'resnet50x16_clip.openai': _cfgr( 'resnet50x16_clip.openai': _cfgr(
hf_hub_id='timm/', hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=768, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12), fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
classifier = 'head.proj', classifier = 'head.proj',
), ),
'resnet50x64_clip.openai': _cfgr( 'resnet50x64_clip.openai': _cfgr(
hf_hub_id='timm/', hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14), fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
classifier = 'head.proj', classifier = 'head.proj',
), ),
# avg-pool w/ optional standard classifier head variants
'resnet50_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet101_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet101_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 224, 224), pool_size=(7, 7),
),
'resnet50x4_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x4_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 288, 288), pool_size=(9, 9),
),
'resnet50x16_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x16_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 384, 384), pool_size=(12, 12),
),
'resnet50x64_clip_gap.openai': _cfgr(
hf_hub_id='timm/resnet50x64_clip.openai',
hf_hub_filename='open_clip_pytorch_model.bin',
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 448, 448), pool_size=(14, 14),
),
'resnet50_mlp.untrained': _cfgr(
input_size=(3, 256, 256), pool_size=(8, 8),
),
}) })
@ -2631,44 +2652,74 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
def resnet50_clip(pretrained=False, **kwargs) -> ByobNet: def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50 CLIP image tower """ OpenAI Modified ResNet-50 CLIP image tower
""" """
model_args = dict(global_pool='attn_abs') return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnet101_clip(pretrained=False, **kwargs) -> ByobNet: def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-101 CLIP image tower """ OpenAI Modified ResNet-101 CLIP image tower
""" """
model_args = dict(global_pool='attn_abs') return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
return _create_byobnet('resnet101_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet: def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x4 CLIP image tower """ OpenAI Modified ResNet-50x4 CLIP image tower
""" """
model_args = dict(global_pool='attn_abs') return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet: def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x16 CLIP image tower """ OpenAI Modified ResNet-50x16 CLIP image tower
""" """
model_args = dict(global_pool='attn_abs') return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet: def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x64 CLIP image tower """ OpenAI Modified ResNet-50x64 CLIP image tower
""" """
model_args = dict(global_pool='attn_abs') return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet: def resnet50_clip_gap(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50 CLIP image tower w/ avg pool (no attention pool)
"""
return _create_byobnet('resnet50_clip_gap', pretrained=pretrained, **kwargs)
@register_model
def resnet101_clip_gap(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-101 CLIP image tower w/ avg pool (no attention pool)
"""
return _create_byobnet('resnet101_clip_gap', pretrained=pretrained, **kwargs)
@register_model
def resnet50x4_clip_gap(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x4 CLIP image tower w/ avg pool (no attention pool)
"""
return _create_byobnet('resnet50x4_clip_gap', pretrained=pretrained, **kwargs)
@register_model
def resnet50x16_clip_gap(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x16 CLIP image tower w/ avg pool (no attention pool)
"""
return _create_byobnet('resnet50x16_clip_gap', pretrained=pretrained, **kwargs)
@register_model
def resnet50x64_clip_gap(pretrained=False, **kwargs) -> ByobNet:
""" OpenAI Modified ResNet-50x64 CLIP image tower w/ avg pool (no attention pool)
"""
return _create_byobnet('resnet50x64_clip_gap', pretrained=pretrained, **kwargs)
@register_model
def resnet50_mlp(pretrained=False, **kwargs) -> ByobNet:
""" """
""" """
return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs) return _create_byobnet('resnet50_mlp', pretrained=pretrained, **kwargs)