Mapping OpenAI CLIP Modified ResNet weights -> ByobNet. Improve AttentionPool2d layers. Fix #1731
This commit is contained in:
parent 7702d9afa1
commit 5efa15b2a2
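
Once merged, the converted CLIP image towers are reachable through the standard timm factory. A minimal usage sketch (assumes a timm build containing this commit and network access for the pretrained weights):

import timm
import torch

model = timm.create_model('resnet50_clip.openai', pretrained=True)  # image tower only, num_classes=0
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # pooled CLIP image embedding, shape (1, 1024)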
@@ -20,6 +20,7 @@ class AttentionPoolLatent(nn.Module):
             out_features: int = None,
             embed_dim: int = None,
             num_heads: int = 8,
+            feat_size: Optional[int] = None,
             mlp_ratio: float = 4.0,
             qkv_bias: bool = True,
             qk_norm: bool = False,
@@ -36,13 +37,14 @@ class AttentionPoolLatent(nn.Module):
         assert embed_dim % num_heads == 0
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
+        self.feat_size = feat_size
         self.scale = self.head_dim ** -0.5
         self.pool = pool_type
         self.fused_attn = use_fused_attn()
 
         if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
+            assert feat_size is not None
+            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
         else:
             self.pos_embed = None
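
With pos_embed='abs', the latent pool now requires feat_size at construction rather than reading an unset attribute. A standalone sketch of the parameter it allocates (plain tensors, not the timm module):

import torch
import torch.nn as nn

feat_size = 7 * 7    # flattened H * W of the incoming feature map
in_features = 768
pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
tokens = torch.randn(2, feat_size, in_features)
tokens = tokens + pos_embed  # broadcasts over the batch dimension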
@@ -7,12 +7,14 @@ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/cli
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Union, Tuple
+from typing import Optional, Union, Tuple
 
 import torch
 import torch.nn as nn
 
 from .config import use_fused_attn
 from .helpers import to_2tuple
+from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_
@@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module):
     NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
     train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
     """
     fused_attn: torch.jit.Final[bool]
 
     def __init__(
             self,
             in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            out_features: Optional[int] = None,
+            ref_feat_size: Union[int, Tuple[int, int]] = 7,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        ref_feat_size = to_2tuple(ref_feat_size)
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.num_heads = num_heads
-        assert embed_dim % num_heads == 0
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
-        self.pos_embed = RotaryEmbedding(self.head_dim)
         self.fused_attn = use_fused_attn()
 
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
 
     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
 
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
+        rse, rce = self.pos_embed.get_embed((H, W))
+        q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
+        k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
 
-        qc, q = q[:, :, :1], q[:, :, 1:]
-        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
-        q = apply_rot_embed(q, sin_emb, cos_emb)
-        q = torch.cat([qc, q], dim=2)
-
-        kc, k = k[:, :, :1], k[:, :, 1:]
-        k = apply_rot_embed(k, sin_emb, cos_emb)
-        k = torch.cat([kc, k], dim=2)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
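
A quick forward-pass sketch for the reworked module (RotAttentionPool2d is imported from timm.layers elsewhere in this commit; shapes assumed):

import torch
from timm.layers import RotAttentionPool2d

pool = RotAttentionPool2d(in_features=2048, out_features=1024, ref_feat_size=7)
out = pool(torch.randn(2, 2048, 7, 7))  # NCHW feature map -> (2, 1024) pooled embedding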
@@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module):
 
     NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
     """
     fused_attn: torch.jit.Final[bool]
 
     def __init__(
             self,
             in_features: int,
-            feat_size: Union[int, Tuple[int, int]],
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 4,
+            feat_size: Union[int, Tuple[int, int]] = 7,
+            out_features: Optional[int] = None,
+            embed_dim: Optional[int] = None,
+            head_dim: Optional[int] = 64,
+            num_heads: Optional[int] = None,
             qkv_bias: bool = True,
+            qkv_separate: bool = False,
     ):
         super().__init__()
-
         embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
+        self.in_features = in_features
+        self.out_features = out_features or in_features
+        if num_heads is not None:
+            assert embed_dim % num_heads == 0
+            head_dim = embed_dim // num_heads
+        else:
+            assert embed_dim % head_dim == 0
+            num_heads = embed_dim // head_dim
         self.feat_size = to_2tuple(feat_size)
-        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(embed_dim, out_features)
+        self.seq_len = self.feat_size[0] * self.feat_size[1]
         self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = head_dim
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
-        spatial_dim = self.feat_size[0] * self.feat_size[1]
-        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+        if qkv_separate:
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.qkv = None
+        else:
+            self.q = self.k = self.v = None
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(embed_dim, self.out_features)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+
+        self.init_weights()
+
+    def init_weights(self, zero_init_last: bool = False):
+        if self.qkv is None:
+            in_features = self.q.in_features
+            trunc_normal_(self.q.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.q.bias)
+            trunc_normal_(self.k.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.k.bias)
+            trunc_normal_(self.v.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.v.bias)
+        else:
+            in_features = self.qkv.in_features
+            trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+            nn.init.zeros_(self.qkv.bias)
         trunc_normal_(self.pos_embed, std=in_features ** -0.5)
-        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
-        nn.init.zeros_(self.qkv.bias)
 
     def forward(self, x):
         B, _, H, W = x.shape
         N = H * W
-        assert self.feat_size[0] == H
-        assert self.feat_size[1] == W
-        x = x.reshape(B, -1, N).permute(0, 2, 1)
+        x = x.flatten(2).transpose(1, 2)
         x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
-        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+        if self.seq_len != N:
+            pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
+        else:
+            pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
+        x = x + pos_embed
 
-        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
+        if self.qkv is None:
+            q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+            v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
+        else:
+            x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            q, k, v = x.unbind(0)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+        if self.fused_attn:
+            x = nn.functional.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N + 1, -1)
         x = self.proj(x)
         return x[:, 0]
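
Because forward now resamples the learned pos_embed whenever the runtime grid differs from feat_size, the fixed-size asserts are gone. A hedged sketch (values assumed):

import torch
from timm.layers import AttentionPool2d

pool = AttentionPool2d(in_features=2048, feat_size=7, out_features=1024, qkv_separate=True)
out = pool(torch.randn(2, 2048, 10, 10))  # pos_embed resampled 7x7 -> 10x10 on the fly; out: (2, 1024)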
@@ -24,8 +24,6 @@ def _create_pool(
 ):
     flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
     if not pool_type:
-        assert num_classes == 0 or use_conv,\
-            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
         flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
     global_pool = SelectAdaptivePool2d(
         pool_type=pool_type,
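
Dropping that assert lets a Linear classifier sit behind externally pooled features (e.g. an attention pool) with internal pooling disabled. A hedged sketch, assuming ClassifierHead's current behavior with pool_type='':

import torch
from timm.layers import ClassifierHead

head = ClassifierHead(1024, 1000, pool_type='')  # pooling pass-through, Linear classifier
logits = head(torch.randn(2, 1024))              # input already pooled to (B, C)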
@@ -312,7 +312,6 @@ class RotaryEmbedding(nn.Module):
                 temperature=temperature,
                 step=1,
             )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
@@ -36,8 +36,8 @@ from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
 import torch
 import torch.nn as nn
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -70,15 +70,23 @@ class ByoModelCfg:
     downsample: str = 'conv1x1'
     stem_type: str = '3x3'
     stem_pool: Optional[str] = 'maxpool'
-    stem_chs: int = 32
+    stem_chs: Union[int, List[int], Tuple[int, ...]] = 32
     width_factor: float = 1.0
     num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
     zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
     fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation
 
     # layer config
     act_layer: str = 'relu'
     norm_layer: str = 'batchnorm'
+    aa_layer: str = ''
+
+    # Head config
+    attn_pool: str = ''
+    head_hidden_size: Optional[int] = None  # feat dim of MLP head or AttentionPool output
+    head_type: str = ''
 
     # Block config
     # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
     attn_layer: Optional[str] = None
     attn_kwargs: dict = field(default_factory=lambda: dict())
@@ -296,10 +304,7 @@ class BottleneckBlock(nn.Module):
         mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
         groups = num_groups(group_size, mid_chs)
 
-        self.shortcut = create_shortcut(
-            downsample, in_chs, out_chs,
-            stride=stride, dilation=dilation, apply_act=False, layers=layers,
-        )
-
         self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
         self.conv2_kxk = layers.conv_norm_act(
@@ -316,7 +321,10 @@ class BottleneckBlock(nn.Module):
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+        self.shortcut = create_shortcut(
+            downsample, in_chs, out_chs,
+            stride=stride, dilation=dilation, apply_act=False, layers=layers,
+        )
 
     def init_weights(self, zero_init_last: bool = False):
         if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
@@ -917,7 +925,7 @@ class Stem(nn.Sequential):
     def __init__(
             self,
             in_chs: int,
-            out_chs: int,
+            out_chs: Union[int, List[int], Tuple[int, ...]],
             kernel_size: int = 3,
             stride: int = 4,
             pool: str = 'maxpool',
@@ -961,10 +969,19 @@ class Stem(nn.Sequential):
             curr_stride *= s
             prev_feat = conv_name
 
-        if pool and 'max' in pool.lower():
+        if pool:
+            pool = pool.lower()
+            assert pool in ('max', 'maxpool', 'avg', 'avgpool', 'max2', 'avg2')
             last_feat_idx = i
             self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
-            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
+            if pool == 'max2':
+                self.add_module('pool', nn.MaxPool2d(2))
+            elif pool == 'avg2':
+                self.add_module('pool', nn.AvgPool2d(2))
+            elif 'max' in pool:
+                self.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+            elif 'avg' in pool:
+                self.add_module('pool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False))
             curr_stride *= 2
             prev_feat = 'pool'
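
For reference, the new 'avg2' option is a kernel-2/stride-2 average pool, matching the CLIP ResNet stem. A quick shape check (standalone sketch):

import torch
import torch.nn as nn

pool = nn.AvgPool2d(2)  # the 'avg2' stem pool
print(pool(torch.randn(1, 64, 112, 112)).shape)  # torch.Size([1, 64, 56, 56])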
@@ -1012,11 +1029,14 @@ def create_byob_stem(
         else:
             stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
     else:
-        # 3x3 stem conv as in RegNet is the default
-        if pool_type:
-            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+        if isinstance(out_chs, (tuple, list)):
+            stem = Stem(in_chs, out_chs, 3, pool=pool_type, layers=layers)
         else:
-            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
+            # 3x3 stem conv as in RegNet is the default
+            if pool_type:
+                stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+            else:
+                stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
 
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
@@ -1138,13 +1158,16 @@ def create_byob_stages(
         prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
 
     feature_info.append(prev_feat)
-    return nn.Sequential(*stages), feature_info
+    return nn.Sequential(*stages), feature_info, feat_size
 
 
-def get_layer_fns(cfg: ByoModelCfg):
+def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True):
     act = get_act_layer(cfg.act_layer)
     norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
-    conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
+    if cfg.aa_layer and allow_aa:
+        conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer)
+    else:
+        conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
    self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
@@ -1191,23 +1214,33 @@ class ByobNet(nn.Module):
         self.grad_checkpointing = False
 
         cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
-        layers = get_layer_fns(cfg)
+        stem_layers = get_layer_fns(cfg, allow_aa=False)  # keep aa off for stem-layers
+        stage_layers = get_layer_fns(cfg)
         if cfg.fixed_input_size:
             assert img_size is not None, 'img_size argument is required for fixed input size model'
+        feat_size = to_2tuple(img_size) if img_size is not None else None
 
         self.feature_info = []
-        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
-        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+        if isinstance(cfg.stem_chs, (list, tuple)):
+            stem_chs = [int(round(c * cfg.width_factor)) for c in cfg.stem_chs]
+        else:
+            stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
+        self.stem, stem_feat = create_byob_stem(
+            in_chs=in_chans,
+            out_chs=stem_chs,
+            stem_type=cfg.stem_type,
+            pool_type=cfg.stem_pool,
+            layers=stem_layers,
+        )
         self.feature_info.extend(stem_feat[:-1])
+        feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
 
-        self.stages, stage_feat = create_byob_stages(
+        self.stages, stage_feat, feat_size = create_byob_stages(
             cfg,
             drop_path_rate,
             output_stride,
             stem_feat[-1],
-            layers=layers,
+            layers=stage_layers,
+            feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
@@ -1216,7 +1249,7 @@ class ByobNet(nn.Module):
         prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
             self.num_features = int(round(cfg.width_factor * cfg.num_features))
-            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
+            self.final_conv = stage_layers.conv_norm_act(prev_chs, self.num_features, 1)
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
@@ -1225,12 +1258,47 @@ class ByobNet(nn.Module):
         self.stage_ends = [f['stage'] for f in self.feature_info]
 
-        self.head_hidden_size = self.num_features
-        self.head = ClassifierHead(
-            self.num_features,
-            num_classes,
-            pool_type=global_pool,
-            drop_rate=self.drop_rate,
-        )
+        assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier')
+        if cfg.head_type == 'norm_mlp_classifier':
+            from timm.layers import NormMlpClassifierHead
+            assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head"
+            self.attn_pool = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                hidden_size=cfg.head_hidden_size,
+                norm_layer=cfg.norm_layer,
+                act_layer=cfg.act_layer,
+            )
+            self.head_hidden_size = self.head.hidden_size
+        else:
+            if cfg.attn_pool == 'abs':
+                from timm.layers import AttentionPool2d
+                self.attn_pool = AttentionPool2d(
+                    self.num_features,
+                    out_features=cfg.head_hidden_size,
+                    feat_size=feat_size,
+                    qkv_separate=True,
+                )
+                self.head_hidden_size = self.attn_pool.out_features
+            elif cfg.attn_pool == 'rot':
+                from timm.layers import RotAttentionPool2d
+                self.attn_pool = RotAttentionPool2d(
+                    self.num_features,
+                    out_features=cfg.head_hidden_size,
+                    ref_feat_size=feat_size,
+                )
+                self.head_hidden_size = self.attn_pool.out_features
+            else:
+                assert not cfg.attn_pool
+                self.attn_pool = nn.Identity()
+
+            self.head = ClassifierHead(
+                self.head_hidden_size,
+                num_classes,
+                pool_type='' if cfg.attn_pool else global_pool,
+                drop_rate=self.drop_rate,
+            )
 
         # init weights
         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@@ -1345,6 +1413,7 @@ class ByobNet(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
+        x = self.attn_pool(x)
         return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
 
     def forward(self, x):
@@ -1834,14 +1903,162 @@ model_cfgs = dict(
         stem_type='one',
         stem_chs=64,
     ),
+
+    resnet50_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=1024,
+    ),
+
+    resnet101_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=23, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=512,
+    ),
+
+    resnet50x4_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=4, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=10, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=2048, s=2, br=0.25),
+        ),
+        width_factor=1.25,
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=640,
+    ),
+
+    resnet50x16_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=6, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=8, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=768,
+    ),
+
+    resnet50x64_clip=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=15, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        attn_pool='abs',
+        head_hidden_size=1024,
+    ),
+
+    resnet50_nmlp=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=(32, 32, 64),
+        stem_type='',
+        stem_pool='avg2',
+        downsample='avg',
+        aa_layer='avg',
+        head_hidden_size=1024,
+        head_type='norm_mlp_classifier',
+    ),
 )
 
 
+def _convert_openai_clip(
+        state_dict: Dict[str, torch.Tensor],
+        model: ByobNet,
+        prefix: str = 'visual.',
+) -> Dict[str, torch.Tensor]:
+    import re
+
+    def _stage_sub(m):
+        stage_idx = int(m.group(1)) - 1
+        layer_idx, layer_type, layer_id = int(m.group(2)), m.group(3), int(m.group(4))
+        prefix_str = f'stages.{stage_idx}.{layer_idx}.'
+        id_map = {1: 'conv1_1x1.', 2: 'conv2_kxk.', 3: 'conv3_1x1.'}
+        suffix_str = id_map[layer_id] + layer_type
+        return prefix_str + suffix_str
+
+    def _down_sub(m):
+        stage_idx = int(m.group(1)) - 1
+        layer_idx, layer_id = int(m.group(2)), int(m.group(3))
+        return f'stages.{stage_idx}.{layer_idx}.shortcut.' + ('conv.conv' if layer_id == 0 else 'conv.bn')
+
+    out_dict = {}
+    for k, v in state_dict.items():
+        if not k.startswith(prefix):
+            continue
+        k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k)
+        k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k)
+        k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.([a-z]+)([0-9])', _stage_sub, k)
+        k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.downsample\.([0-9])', _down_sub, k)
+        if k.startswith(f'{prefix}attnpool'):
+            k = k.replace(prefix + 'attnpool', 'attn_pool')
+            k = k.replace('positional_embedding', 'pos_embed')
+            k = k.replace('q_proj', 'q')
+            k = k.replace('k_proj', 'k')
+            k = k.replace('v_proj', 'v')
+            k = k.replace('c_proj', 'proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
+def checkpoint_filter_fn(
+        state_dict: Dict[str, torch.Tensor],
+        model: ByobNet
+):
+    if 'visual.conv1.weight' in state_dict:
+        state_dict = _convert_openai_clip(state_dict, model)
+    return state_dict
+
+
 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
         model_cfg=model_cfgs[variant],
+        pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True),
+        #pretrained_strict=False,
         **kwargs)
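
For orientation, a hedged sketch of the key remapping _convert_openai_clip performs (sample keys; the lambda mirrors _down_sub under the same assumptions):

import re

# Representative mappings (sample keys, dummy tensors omitted):
#   'visual.conv1.weight'                  -> 'stem.conv1.conv.weight'
#   'visual.layer1.0.conv2.weight'         -> 'stages.0.0.conv2_kxk.conv.weight'
#   'visual.attnpool.positional_embedding' -> 'attn_pool.pos_embed'

k = 'visual.layer1.0.downsample.0.weight'
k = re.sub(
    r'visual\.layer([0-9])\.([0-9])\.downsample\.([0-9])',
    lambda m: f'stages.{int(m.group(1)) - 1}.{m.group(2)}.shortcut.'
              + ('conv.conv' if int(m.group(3)) == 0 else 'conv.bn'),
    k)
print(k)  # stages.0.0.shortcut.conv.conv.weight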
@@ -2035,6 +2252,38 @@ default_cfgs = generate_default_cfgs({
         crop_pct=0.9,
         first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
     ),
+
+    'resnet50_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
+    ),
+    'resnet101_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7),
+    ),
+    'resnet50x4_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9),
+    ),
+    'resnet50x16_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12),
+    ),
+    'resnet50x64_clip.openai': _cfgr(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14),
+    ),
+
 })
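
The OPENAI_CLIP_MEAN/STD and fixed input sizes flow into preprocessing through the usual data-config path; a hedged sketch:

import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('resnet50x4_clip.openai', pretrained=False)
cfg = resolve_data_config({}, model=model)  # picks up CLIP mean/std and the 288x288 input size
transform = create_transform(**cfg)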
@@ -2337,3 +2586,45 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
     """
     """
     return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50 CLIP image tower
+    """
+    return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet101_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-101 CLIP image tower
+    """
+    return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50x4 CLIP image tower
+    """
+    return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50x16 CLIP image tower
+    """
+    return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet:
+    """ OpenAI Modified ResNet-50x64 CLIP image tower
+    """
+    return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs)