mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix load of larger ResNet CLIP models, experimenting with making AttentionPool *the* head, seems to fine-tune better, one less layer.
This commit is contained in:
parent
5e9ff5798f
commit
30ffa152de
@ -41,9 +41,10 @@ class RotAttentionPool2d(nn.Module):
|
|||||||
num_heads: Optional[int] = None,
|
num_heads: Optional[int] = None,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
qkv_separate: bool = False,
|
qkv_separate: bool = False,
|
||||||
|
drop: float = 0.,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
embed_dim = embed_dim or in_features
|
self.embed_dim = embed_dim = embed_dim or in_features
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features or in_features
|
self.out_features = out_features or in_features
|
||||||
ref_feat_size = to_2tuple(ref_feat_size)
|
ref_feat_size = to_2tuple(ref_feat_size)
|
||||||
@ -82,7 +83,7 @@ class RotAttentionPool2d(nn.Module):
|
|||||||
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
|
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
|
||||||
nn.init.zeros_(self.qkv.bias)
|
nn.init.zeros_(self.qkv.bias)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, pre_logits: bool = False):
|
||||||
B, _, H, W = x.shape
|
B, _, H, W = x.shape
|
||||||
N = H * W
|
N = H * W
|
||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
@ -107,8 +108,12 @@ class RotAttentionPool2d(nn.Module):
|
|||||||
attn = attn.softmax(dim=-1)
|
attn = attn.softmax(dim=-1)
|
||||||
x = attn @ v
|
x = attn @ v
|
||||||
x = x.transpose(1, 2).reshape(B, N + 1, -1)
|
x = x.transpose(1, 2).reshape(B, N + 1, -1)
|
||||||
|
x = x[:, 0]
|
||||||
|
x = self.drop(x)
|
||||||
|
if pre_logits:
|
||||||
|
return x
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
return x[:, 0]
|
return x
|
||||||
|
|
||||||
|
|
||||||
class AttentionPool2d(nn.Module):
|
class AttentionPool2d(nn.Module):
|
||||||
@ -132,9 +137,10 @@ class AttentionPool2d(nn.Module):
|
|||||||
num_heads: Optional[int] = None,
|
num_heads: Optional[int] = None,
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
qkv_separate: bool = False,
|
qkv_separate: bool = False,
|
||||||
|
drop: float = 0.,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
embed_dim = embed_dim or in_features
|
self.embed_dim = embed_dim = embed_dim or in_features
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features or in_features
|
self.out_features = out_features or in_features
|
||||||
if num_heads is not None:
|
if num_heads is not None:
|
||||||
@ -158,6 +164,7 @@ class AttentionPool2d(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.q = self.k = self.v = None
|
self.q = self.k = self.v = None
|
||||||
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
|
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
self.proj = nn.Linear(embed_dim, self.out_features)
|
self.proj = nn.Linear(embed_dim, self.out_features)
|
||||||
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
|
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
|
||||||
|
|
||||||
@ -178,15 +185,12 @@ class AttentionPool2d(nn.Module):
|
|||||||
nn.init.zeros_(self.qkv.bias)
|
nn.init.zeros_(self.qkv.bias)
|
||||||
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
|
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, pre_logits: bool = False):
|
||||||
B, _, H, W = x.shape
|
B, _, H, W = x.shape
|
||||||
N = H * W
|
N = H * W
|
||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
|
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
|
||||||
if self.seq_len != N:
|
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
|
||||||
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
|
|
||||||
else:
|
|
||||||
pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
|
|
||||||
x = x + pos_embed
|
x = x + pos_embed
|
||||||
|
|
||||||
if self.qkv is None:
|
if self.qkv is None:
|
||||||
@ -205,5 +209,9 @@ class AttentionPool2d(nn.Module):
|
|||||||
attn = attn.softmax(dim=-1)
|
attn = attn.softmax(dim=-1)
|
||||||
x = attn @ v
|
x = attn @ v
|
||||||
x = x.transpose(1, 2).reshape(B, N + 1, -1)
|
x = x.transpose(1, 2).reshape(B, N + 1, -1)
|
||||||
|
x = x[:, 0]
|
||||||
|
x = self.drop(x)
|
||||||
|
if pre_logits:
|
||||||
|
return x
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
return x[:, 0]
|
return x
|
||||||
|
@ -37,7 +37,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||||
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \
|
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
|
||||||
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
|
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._features import feature_take_indices
|
from ._features import feature_take_indices
|
||||||
@ -82,7 +82,6 @@ class ByoModelCfg:
|
|||||||
aa_layer: str = ''
|
aa_layer: str = ''
|
||||||
|
|
||||||
# Head config
|
# Head config
|
||||||
attn_pool: str = ''
|
|
||||||
head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output
|
head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output
|
||||||
head_type: str = ''
|
head_type: str = ''
|
||||||
|
|
||||||
@ -304,7 +303,10 @@ class BottleneckBlock(nn.Module):
|
|||||||
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
|
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
|
||||||
groups = num_groups(group_size, mid_chs)
|
groups = num_groups(group_size, mid_chs)
|
||||||
|
|
||||||
|
self.shortcut = create_shortcut(
|
||||||
|
downsample, in_chs, out_chs,
|
||||||
|
stride=stride, dilation=dilation, apply_act=False, layers=layers,
|
||||||
|
)
|
||||||
|
|
||||||
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
|
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
|
||||||
self.conv2_kxk = layers.conv_norm_act(
|
self.conv2_kxk = layers.conv_norm_act(
|
||||||
@ -321,10 +323,7 @@ class BottleneckBlock(nn.Module):
|
|||||||
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
|
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
|
||||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
||||||
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
|
||||||
self.shortcut = create_shortcut(
|
|
||||||
downsample, in_chs, out_chs,
|
|
||||||
stride=stride, dilation=dilation, apply_act=False, layers=layers,
|
|
||||||
)
|
|
||||||
def init_weights(self, zero_init_last: bool = False):
|
def init_weights(self, zero_init_last: bool = False):
|
||||||
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
|
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
|
||||||
nn.init.zeros_(self.conv3_1x1.bn.weight)
|
nn.init.zeros_(self.conv3_1x1.bn.weight)
|
||||||
@ -1165,7 +1164,7 @@ def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True):
|
|||||||
act = get_act_layer(cfg.act_layer)
|
act = get_act_layer(cfg.act_layer)
|
||||||
norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
|
norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
|
||||||
if cfg.aa_layer and allow_aa:
|
if cfg.aa_layer and allow_aa:
|
||||||
conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
|
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
|
||||||
else:
|
else:
|
||||||
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
|
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
|
||||||
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||||
@ -1258,6 +1257,7 @@ class ByobNet(nn.Module):
|
|||||||
self.stage_ends = [f['stage'] for f in self.feature_info]
|
self.stage_ends = [f['stage'] for f in self.feature_info]
|
||||||
|
|
||||||
self.head_hidden_size = self.num_features
|
self.head_hidden_size = self.num_features
|
||||||
|
self.global_pool = global_pool
|
||||||
assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
|
assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
|
||||||
if cfg.head_type == 'norm_mlp_classifier':
|
if cfg.head_type == 'norm_mlp_classifier':
|
||||||
from timm.layers import NormMlpClassifierHead
|
from timm.layers import NormMlpClassifierHead
|
||||||
@ -1272,33 +1272,61 @@ class ByobNet(nn.Module):
|
|||||||
)
|
)
|
||||||
self.head_hidden_size = self.head.hidden_size
|
self.head_hidden_size = self.head.hidden_size
|
||||||
else:
|
else:
|
||||||
if cfg.attn_pool == 'abs':
|
# FIXME evaluating different head vs pool configurations
|
||||||
from timm.layers import AttentionPool2d
|
if False:
|
||||||
self.attn_pool = AttentionPool2d(
|
if global_pool == 'attn_abs':
|
||||||
self.num_features,
|
from timm.layers import AttentionPool2d
|
||||||
out_features=cfg.head_hidden_size,
|
self.attn_pool = AttentionPool2d(
|
||||||
feat_size=feat_size,
|
self.num_features,
|
||||||
qkv_separate=True,
|
out_features=cfg.head_hidden_size,
|
||||||
)
|
feat_size=feat_size,
|
||||||
self.head_hidden_size = self.attn_pool.out_features
|
qkv_separate=True,
|
||||||
elif cfg.attn_pool == 'rot':
|
)
|
||||||
from timm.layers import RotAttentionPool2d
|
global_pool = '' # clear for ClassifierHead
|
||||||
self.attn_pool = RotAttentionPool2d(
|
self.head_hidden_size = self.attn_pool.out_features
|
||||||
self.num_features,
|
elif global_pool =='attn_rot':
|
||||||
out_features=cfg.head_hidden_size,
|
from timm.layers import RotAttentionPool2d
|
||||||
ref_feat_size=feat_size,
|
self.attn_pool = RotAttentionPool2d(
|
||||||
)
|
self.num_features,
|
||||||
self.head_hidden_size = self.attn_pool.out_features
|
out_features=cfg.head_hidden_size,
|
||||||
else:
|
ref_feat_size=feat_size,
|
||||||
assert not cfg.attn_pool
|
)
|
||||||
self.attn_pool = nn.Identity()
|
global_pool = '' # clear for ClassifierHead
|
||||||
|
self.head_hidden_size = self.attn_pool.out_features
|
||||||
|
else:
|
||||||
|
self.attn_pool = nn.Identity()
|
||||||
|
|
||||||
self.head = ClassifierHead(
|
self.head = ClassifierHead(
|
||||||
self.head_hidden_size,
|
self.head_hidden_size,
|
||||||
num_classes,
|
num_classes,
|
||||||
pool_type='' if cfg.attn_pool else global_pool,
|
pool_type=global_pool,
|
||||||
drop_rate=self.drop_rate,
|
drop_rate=self.drop_rate,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if global_pool == 'attn_abs':
|
||||||
|
from timm.layers import AttentionPool2d
|
||||||
|
self.head = AttentionPool2d(
|
||||||
|
self.num_features,
|
||||||
|
out_features=num_classes,
|
||||||
|
feat_size=feat_size,
|
||||||
|
qkv_separate=True,
|
||||||
|
)
|
||||||
|
self.head_hidden_size = self.head.embed_dim
|
||||||
|
elif global_pool == 'attn_rot':
|
||||||
|
from timm.layers import RotAttentionPool2d
|
||||||
|
self.head = RotAttentionPool2d(
|
||||||
|
self.num_features,
|
||||||
|
out_features=num_classes,
|
||||||
|
ref_feat_size=feat_size,
|
||||||
|
)
|
||||||
|
self.head_hidden_size = self.head.embed_dim
|
||||||
|
else:
|
||||||
|
self.head = ClassifierHead(
|
||||||
|
self.num_features,
|
||||||
|
num_classes,
|
||||||
|
pool_type=global_pool,
|
||||||
|
drop_rate=self.drop_rate,
|
||||||
|
)
|
||||||
|
|
||||||
# init weights
|
# init weights
|
||||||
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
|
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
|
||||||
@ -1324,6 +1352,9 @@ class ByobNet(nn.Module):
|
|||||||
|
|
||||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
if global_pool is not None:
|
||||||
|
if self.global_pool in ('attn_abs', 'attn_rot'):
|
||||||
|
raise RuntimeError('Cannot change attention pool on head reset.')
|
||||||
self.head.reset(num_classes, global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
def forward_intermediates(
|
def forward_intermediates(
|
||||||
@ -1413,7 +1444,7 @@ class ByobNet(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_head(self, x, pre_logits: bool = False):
|
def forward_head(self, x, pre_logits: bool = False):
|
||||||
x = self.attn_pool(x)
|
#x = self.attn_pool(x)
|
||||||
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -1916,7 +1947,6 @@ model_cfgs = dict(
|
|||||||
stem_pool='avg2',
|
stem_pool='avg2',
|
||||||
downsample='avg',
|
downsample='avg',
|
||||||
aa_layer='avg',
|
aa_layer='avg',
|
||||||
attn_pool='abs',
|
|
||||||
head_hidden_size=1024,
|
head_hidden_size=1024,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -1932,7 +1962,6 @@ model_cfgs = dict(
|
|||||||
stem_pool='avg2',
|
stem_pool='avg2',
|
||||||
downsample='avg',
|
downsample='avg',
|
||||||
aa_layer='avg',
|
aa_layer='avg',
|
||||||
attn_pool='abs',
|
|
||||||
head_hidden_size=512,
|
head_hidden_size=512,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -1949,7 +1978,6 @@ model_cfgs = dict(
|
|||||||
stem_pool='avg2',
|
stem_pool='avg2',
|
||||||
downsample='avg',
|
downsample='avg',
|
||||||
aa_layer='avg',
|
aa_layer='avg',
|
||||||
attn_pool='abs',
|
|
||||||
head_hidden_size=640,
|
head_hidden_size=640,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -1960,12 +1988,12 @@ model_cfgs = dict(
|
|||||||
ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
|
ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
|
||||||
ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
|
ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
|
||||||
),
|
),
|
||||||
|
width_factor=1.5,
|
||||||
stem_chs=(32, 32, 64),
|
stem_chs=(32, 32, 64),
|
||||||
stem_type='',
|
stem_type='',
|
||||||
stem_pool='avg2',
|
stem_pool='avg2',
|
||||||
downsample='avg',
|
downsample='avg',
|
||||||
aa_layer='avg',
|
aa_layer='avg',
|
||||||
attn_pool='abs',
|
|
||||||
head_hidden_size=768,
|
head_hidden_size=768,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -1976,12 +2004,12 @@ model_cfgs = dict(
|
|||||||
ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
|
ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
|
||||||
ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
|
ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
|
||||||
),
|
),
|
||||||
|
width_factor=2.0,
|
||||||
stem_chs=(32, 32, 64),
|
stem_chs=(32, 32, 64),
|
||||||
stem_type='',
|
stem_type='',
|
||||||
stem_pool='avg2',
|
stem_pool='avg2',
|
||||||
downsample='avg',
|
downsample='avg',
|
||||||
aa_layer='avg',
|
aa_layer='avg',
|
||||||
attn_pool='abs',
|
|
||||||
head_hidden_size=1024,
|
head_hidden_size=1024,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -2029,10 +2057,10 @@ def _convert_openai_clip(
|
|||||||
continue
|
continue
|
||||||
k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
|
k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
|
||||||
k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
|
k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
|
||||||
k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.([a-z]+)([0-9])', _stage_sub, k)
|
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k)
|
||||||
k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.downsample\.([0-9])', _down_sub, k)
|
k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k)
|
||||||
if k.startswith(f'{prefix}attnpool'):
|
if k.startswith(f'{prefix}attnpool'):
|
||||||
k = k.replace(prefix + 'attnpool', 'attn_pool')
|
k = k.replace(prefix + 'attnpool', 'head') #'attn_pool')
|
||||||
k = k.replace('positional_embedding', 'pos_embed')
|
k = k.replace('positional_embedding', 'pos_embed')
|
||||||
k = k.replace('q_proj', 'q')
|
k = k.replace('q_proj', 'q')
|
||||||
k = k.replace('k_proj', 'k')
|
k = k.replace('k_proj', 'k')
|
||||||
@ -2053,13 +2081,19 @@ def checkpoint_filter_fn(
|
|||||||
|
|
||||||
|
|
||||||
def _create_byobnet(variant, pretrained=False, **kwargs):
|
def _create_byobnet(variant, pretrained=False, **kwargs):
|
||||||
|
strict = True
|
||||||
|
if 'clip' in variant and kwargs.get('global_pool', None) != 'attn_abs':
|
||||||
|
# NOTE: a hack to allow removing attention pool from CLIP ResNet variants
|
||||||
|
strict = False
|
||||||
|
|
||||||
return build_model_with_cfg(
|
return build_model_with_cfg(
|
||||||
ByobNet, variant, pretrained,
|
ByobNet, variant, pretrained,
|
||||||
model_cfg=model_cfgs[variant],
|
model_cfg=model_cfgs[variant],
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=checkpoint_filter_fn,
|
||||||
feature_cfg=dict(flatten_sequential=True),
|
feature_cfg=dict(flatten_sequential=True),
|
||||||
#pretrained_strict=False,
|
pretrained_strict=strict,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _cfg(url='', **kwargs):
|
def _cfg(url='', **kwargs):
|
||||||
@ -2257,31 +2291,36 @@ default_cfgs = generate_default_cfgs({
|
|||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||||
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||||
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
|
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
|
||||||
|
classifier = 'head.proj',
|
||||||
),
|
),
|
||||||
'resnet101_clip.openai': _cfgr(
|
'resnet101_clip.openai': _cfgr(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||||
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||||
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7)
|
fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
|
||||||
|
classifier='head.proj',
|
||||||
),
|
),
|
||||||
'resnet50x4_clip.openai': _cfgr(
|
'resnet50x4_clip.openai': _cfgr(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||||
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||||
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9)
|
fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
|
||||||
|
classifier = 'head.proj',
|
||||||
),
|
),
|
||||||
'resnet50x16_clip.openai': _cfgr(
|
'resnet50x16_clip.openai': _cfgr(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||||
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||||
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12)
|
fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
|
||||||
|
classifier = 'head.proj',
|
||||||
),
|
),
|
||||||
'resnet50x64_clip.openai': _cfgr(
|
'resnet50x64_clip.openai': _cfgr(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||||
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||||
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14)
|
fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
|
||||||
|
classifier = 'head.proj',
|
||||||
),
|
),
|
||||||
|
|
||||||
})
|
})
|
||||||
@ -2592,35 +2631,40 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
|
|||||||
def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
|
def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
|
||||||
""" OpenAI Modified ResNet-50 CLIP image tower
|
""" OpenAI Modified ResNet-50 CLIP image tower
|
||||||
"""
|
"""
|
||||||
return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
|
model_args = dict(global_pool='attn_abs')
|
||||||
|
return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
|
def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
|
||||||
""" OpenAI Modified ResNet-101 CLIP image tower
|
""" OpenAI Modified ResNet-101 CLIP image tower
|
||||||
"""
|
"""
|
||||||
return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
|
model_args = dict(global_pool='attn_abs')
|
||||||
|
return _create_byobnet('resnet101_clip', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
|
def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
|
||||||
""" OpenAI Modified ResNet-50x4 CLIP image tower
|
""" OpenAI Modified ResNet-50x4 CLIP image tower
|
||||||
"""
|
"""
|
||||||
return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
|
model_args = dict(global_pool='attn_abs')
|
||||||
|
return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
|
def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
|
||||||
""" OpenAI Modified ResNet-50x16 CLIP image tower
|
""" OpenAI Modified ResNet-50x16 CLIP image tower
|
||||||
"""
|
"""
|
||||||
return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
|
model_args = dict(global_pool='attn_abs')
|
||||||
|
return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
|
def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
|
||||||
""" OpenAI Modified ResNet-50x64 CLIP image tower
|
""" OpenAI Modified ResNet-50x64 CLIP image tower
|
||||||
"""
|
"""
|
||||||
return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
|
model_args = dict(global_pool='attn_abs')
|
||||||
|
return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user