Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Refactoring, cleanup, improved test coverage.
* Add eca_nfnet_l2 weights, 84.7 @ 384x384
* All 'non-std' (ie transformer / mlp) models have classifier / default_cfg test added
* Fix #694 reset_classifier / num_features / forward_features / num_classes=0 consistency for transformer / mlp models
* Add direct loading of npz to vision transformer (pure transformer so far, hybrid to come)
* Rename vit_deit* to deit_*
* Remove some deprecated vit hybrid model defs
* Clean up classifier flatten for conv classifiers and unusual cases (mobilenetv3/ghostnet)
* Remove explicit model fns for levit conv, just pass in arg
This commit is contained in:
parent ba2ca4b464
commit 8880f696b6
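A quick usage sketch of the consistency the new non-std model test enforces (Python; assumes a timm build that includes this commit, and 'deit_tiny_patch16_224' is only an example model name):

    import torch
    import timm

    model = timm.create_model('deit_tiny_patch16_224', pretrained=False).eval()
    x = torch.randn(1, 3, 224, 224)

    # forward_features() returns unpooled features; dim 1 should match num_features
    feats = model.forward_features(x)
    feats = feats[0] if isinstance(feats, tuple) else feats
    assert feats.shape[1] == model.num_features

    # reset_classifier(0) removes the head; forward() then returns pooled features
    model.reset_classifier(0)
    pooled = model(x)
    pooled = pooled[0] if isinstance(pooled, tuple) else pooled
    assert pooled.ndim == 2 and pooled.shape[1] == model.num_features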
@@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*']
    'convit_*', 'levit*', 'visformer*', 'deit*']
NUM_NON_STD = len(NON_STD_FILTERS)

# exclude models that cause specific test failures
@@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size):
    state_dict = model.state_dict()
    cfg = model.default_cfg

    classifier = cfg['classifier']
    pool_size = cfg['pool_size']
    input_size = model.default_cfg['input_size']

@@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size):
    assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]

    # check classifier name matches default_cfg
    assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
    classifier = cfg['classifier']
    if not isinstance(classifier, (tuple, list)):
        classifier = classifier,
    for c in classifier:
        assert c + ".weight" in state_dict.keys(), f'{c} not in model params'

    # check first conv(s) names match default_cfg
    first_conv = cfg['first_conv']
    if isinstance(first_conv, str):
        first_conv = (first_conv,)
    assert isinstance(first_conv, (tuple, list))
    for fc in first_conv:
        assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'


@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs_non_std(model_name, batch_size):
    """Run a single forward pass with each model"""
    model = create_model(model_name, pretrained=False)
    model.eval()
    state_dict = model.state_dict()
    cfg = model.default_cfg

    input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE)
    if max(input_size) > MAX_FWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    input_tensor = torch.randn((batch_size, *input_size))

    # test forward_features (always unpooled)
    outputs = model.forward_features(input_tensor)
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    assert outputs.shape[1] == model.num_features

    # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
    model.reset_classifier(0)
    outputs = model.forward(input_tensor)
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    assert len(outputs.shape) == 2
    assert outputs.shape[1] == model.num_features

    # check classifier name matches default_cfg
    classifier = cfg['classifier']
    if not isinstance(classifier, (tuple, list)):
        classifier = classifier,
    for c in classifier:
        assert c + ".weight" in state_dict.keys(), f'{c} not in model params'

    # check first conv(s) names match default_cfg
    first_conv = cfg['first_conv']
@ -74,11 +74,11 @@ default_cfgs = dict(
|
||||
class ClassAttn(nn.Module):
|
||||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to do CA
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.k = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
|
||||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add CA and LayerScale
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
|
||||
mlp_block=Mlp, init_values=1e-4):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_block(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
@ -134,14 +134,14 @@ class LayerScaleBlockClassAttn(nn.Module):
|
||||
class TalkingHeadAttn(nn.Module):
|
||||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
|
||||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add layerScale
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
|
||||
mlp_block=Mlp, init_values=1e-4):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_block(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
@ -202,7 +202,7 @@ class Cait(nn.Module):
|
||||
# with slight modifications to adapt to our cait models
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
global_pool=None,
|
||||
@ -235,14 +235,14 @@ class Cait(nn.Module):
|
||||
dpr = [drop_path_rate for i in range(depth)]
|
||||
self.blocks = nn.ModuleList([
|
||||
block_layers(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
|
||||
for i in range(depth)])
|
||||
|
||||
self.blocks_token_only = nn.ModuleList([
|
||||
block_layers_token(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
|
||||
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
|
||||
act_layer=act_layer, attn_block=attn_block_token_only,
|
||||
mlp_block=mlp_block_token_only, init_values=init_scale)
|
||||
@ -270,6 +270,13 @@ class Cait(nn.Module):
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
@ -293,7 +300,6 @@ class Cait(nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
@ -335,6 +335,8 @@ class CoaT(nn.Module):
|
||||
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
|
||||
self.return_interm_layers = return_interm_layers
|
||||
self.out_features = out_features
|
||||
self.embed_dims = embed_dims
|
||||
self.num_features = embed_dims[-1]
|
||||
self.num_classes = num_classes
|
||||
|
||||
# Patch embeddings.
|
||||
@ -441,10 +443,10 @@ class CoaT(nn.Module):
|
||||
# CoaT series: Aggregate features of last three scales for classification.
|
||||
assert embed_dims[1] == embed_dims[2] == embed_dims[3]
|
||||
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
|
||||
self.head = nn.Linear(embed_dims[3], num_classes)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
# CoaT-Lite series: Use feature of last scale for classification.
|
||||
self.head = nn.Linear(embed_dims[3], num_classes)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
# Initialize weights.
|
||||
trunc_normal_(self.cls_token1, std=.02)
|
||||
@ -471,7 +473,7 @@ class CoaT(nn.Module):
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def insert_cls(self, x, cls_token):
|
||||
""" Insert CLS token. """
|
||||
|
@ -57,13 +57,13 @@ default_cfgs = {
|
||||
|
||||
|
||||
class GPSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
|
||||
locality_strength=1.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.dim = dim
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
self.locality_strength = locality_strength
|
||||
|
||||
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
||||
@ -142,11 +142,11 @@ class GPSA(nn.Module):
|
||||
|
||||
|
||||
class MHSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
@ -191,19 +191,16 @@ class MHSA(nn.Module):
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.use_gpsa = use_gpsa
|
||||
if self.use_gpsa:
|
||||
self.attn = GPSA(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
||||
proj_drop=drop, **kwargs)
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
|
||||
else:
|
||||
self.attn = MHSA(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
||||
proj_drop=drop, **kwargs)
|
||||
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
@ -220,7 +217,7 @@ class ConViT(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
|
||||
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
|
||||
super().__init__()
|
||||
@ -250,13 +247,13 @@ class ConViT(nn.Module):
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_gpsa=True,
|
||||
locality_strength=locality_strength)
|
||||
if i < local_up_to_layer else
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_gpsa=False)
|
||||
for i in range(depth)])
|
||||
|
@@ -288,6 +288,8 @@ class DLA(nn.Module):
        self.num_features = channels[-1]
        self.global_pool, self.fc = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -314,6 +316,7 @@ class DLA(nn.Module):
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

    def forward_features(self, x):
        x = self.base_layer(x)
@@ -331,8 +334,7 @@ class DLA(nn.Module):
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        if not self.global_pool.is_identity():
            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
        x = self.flatten(x)
        return x
@@ -237,6 +237,7 @@ class DPN(nn.Module):
        # Using 1x1 conv for the FC layer to allow the extra pooling scheme
        self.global_pool, self.classifier = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

    def get_classifier(self):
        return self.classifier
@@ -245,6 +246,7 @@ class DPN(nn.Module):
        self.num_classes = num_classes
        self.global_pool, self.classifier = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

    def forward_features(self, x):
        return self.features(x)
@@ -255,8 +257,7 @@ class DPN(nn.Module):
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        if not self.global_pool.is_identity():
            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
        x = self.flatten(x)
        return x
@@ -133,7 +133,7 @@ class GhostBottleneck(nn.Module):


class GhostNet(nn.Module):
    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32):
    def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
@@ -178,9 +178,10 @@ class GhostNet(nn.Module):

        # building last several layers
        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(out_chs, num_classes)

    def get_classifier(self):
@@ -190,6 +191,7 @@ class GhostNet(nn.Module):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
@@ -204,8 +206,7 @@ class GhostNet(nn.Module):

    def forward(self, x):
        x = self.forward_features(x)
        if not self.global_pool.is_identity():
            x = x.view(x.size(0), -1)
        x = self.flatten(x)
        if self.dropout > 0.:
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.classifier(x)
@@ -45,6 +45,13 @@ def load_state_dict(checkpoint_path, use_ema=False):


def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError('Model cannot load numpy checkpoint')
        return
    state_dict = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(state_dict, strict=strict)

@@ -477,3 +484,25 @@ def model_parameters(model, exclude_head=False):
        return [p for p in model.parameters()][:-2]
    else:
        return model.parameters()


def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
    if not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        yield name, module
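The load_checkpoint() change above routes .npz / .npy files to a model-specific load_pretrained() method rather than torch.load(). A rough sketch, assuming a model that implements load_pretrained() (e.g. ResNetV2 or the updated vision transformer) and a hypothetical local checkpoint path:

    import timm
    from timm.models.helpers import load_checkpoint

    model = timm.create_model('resnetv2_50x1_bitm', pretrained=False)
    npz_path = 'BiT-M-R50x1.npz'  # hypothetical local .npz file
    load_checkpoint(model, npz_path)  # dispatches to model.load_pretrained(npz_path)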
@@ -55,7 +55,7 @@ class FastAdaptiveAvgPool2d(nn.Module):
        self.flatten = flatten

    def forward(self, x):
        return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
        return x.mean((2, 3), keepdim=not self.flatten)


class AdaptiveAvgMaxPool2d(nn.Module):
@@ -82,13 +82,13 @@ class SelectAdaptivePool2d(nn.Module):
    def __init__(self, output_size=1, pool_type='fast', flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
        self.flatten = flatten
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if pool_type == '':
            self.pool = nn.Identity()  # pass through
        elif pool_type == 'fast':
            assert output_size == 1
            self.pool = FastAdaptiveAvgPool2d(self.flatten)
            self.flatten = False
            self.pool = FastAdaptiveAvgPool2d(flatten)
            self.flatten = nn.Identity()
        elif pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool2d(output_size)
        elif pool_type == 'avgmax':
@@ -101,12 +101,11 @@ class SelectAdaptivePool2d(nn.Module):
            assert False, 'Invalid pool type: %s' % pool_type

    def is_identity(self):
        return self.pool_type == ''
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        if self.flatten:
            x = x.flatten(1)
        x = self.flatten(x)
        return x

    def feat_mult(self):
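With the change above, SelectAdaptivePool2d's flatten step is an nn.Flatten(1) / nn.Identity() submodule instead of a Python bool checked in forward(). A small illustrative sketch (shapes chosen arbitrarily):

    import torch
    from timm.models.layers import SelectAdaptivePool2d

    pool = SelectAdaptivePool2d(pool_type='avg', flatten=True)
    x = torch.randn(2, 512, 7, 7)
    print(pool(x).shape)       # torch.Size([2, 512]) - pooled, then flattened by the module
    print(pool.is_identity())  # False; True only when pool_type is '' (pooling disabled)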
@@ -20,7 +20,7 @@ def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
    return global_pool, num_pooled_features


def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
@@ -45,11 +45,12 @@ class ClassifierHead(nn.Module):
        self.drop_rate = drop_rate
        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        self.flatten_after_fc = use_conv and pool_type
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def forward(self, x):
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        x = self.flatten(x)
        return x
@@ -40,6 +40,12 @@ class GluMlp(nn.Module):
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x, gates = x.chunk(2, dim=-1)
@ -84,63 +84,33 @@ __all__ = ['Levit']
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
|
||||
def levit_128s(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_128(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_192(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_256(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_384(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
|
||||
return create_levit(
|
||||
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
class ConvNorm(nn.Sequential):
|
||||
@ -427,6 +397,9 @@ class AttentionSubsample(nn.Module):
|
||||
|
||||
class Levit(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
|
||||
NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
|
||||
w/ train scripts that don't take tuple outputs,
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -447,7 +420,8 @@ class Levit(nn.Module):
|
||||
attn_act_layer='hard_swish',
|
||||
distillation=True,
|
||||
use_conv=False,
|
||||
drop_path=0):
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.):
|
||||
super().__init__()
|
||||
act_layer = get_act_layer(act_layer)
|
||||
attn_act_layer = get_act_layer(attn_act_layer)
|
||||
@ -486,7 +460,7 @@ class Levit(nn.Module):
|
||||
Attention(
|
||||
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
|
||||
resolution=resolution, use_conv=use_conv),
|
||||
drop_path))
|
||||
drop_path_rate))
|
||||
if mr > 0:
|
||||
h = int(ed * mr)
|
||||
self.blocks.append(
|
||||
@ -494,7 +468,7 @@ class Levit(nn.Module):
|
||||
ln_layer(ed, h, resolution=resolution),
|
||||
act_layer(),
|
||||
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
|
||||
), drop_path))
|
||||
), drop_path_rate))
|
||||
if do[0] == 'Subsample':
|
||||
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
resolution_ = (resolution - 1) // do[5] + 1
|
||||
@ -511,26 +485,45 @@ class Levit(nn.Module):
|
||||
ln_layer(embed_dim[i + 1], h, resolution=resolution),
|
||||
act_layer(),
|
||||
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
|
||||
), drop_path))
|
||||
), drop_path_rate))
|
||||
self.blocks = nn.Sequential(*self.blocks)
|
||||
|
||||
# Classifier head
|
||||
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = None
|
||||
if distillation:
|
||||
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
self.head_dist = None
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
||||
|
||||
def forward(self, x):
|
||||
def get_classifier(self):
|
||||
if self.head_dist is None:
|
||||
return self.head
|
||||
else:
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='', distillation=None):
|
||||
self.num_classes = num_classes
|
||||
self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
if distillation is not None:
|
||||
self.distillation = distillation
|
||||
if self.distillation:
|
||||
self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
self.head_dist = None
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if not self.use_conv:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.blocks(x)
|
||||
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if self.head_dist is not None:
|
||||
x, x_dist = self.head(x), self.head_dist(x)
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
|
@ -45,7 +45,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
|
||||
from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
||||
from .registry import register_model
|
||||
|
||||
@ -169,6 +169,11 @@ class SpatialGatingUnit(nn.Module):
|
||||
self.norm = norm_layer(gate_dim)
|
||||
self.proj = nn.Linear(seq_len, seq_len)
|
||||
|
||||
def init_weights(self):
|
||||
# special init for the projection gate, called as override by base model init
|
||||
nn.init.normal_(self.proj.weight, std=1e-6)
|
||||
nn.init.ones_(self.proj.bias)
|
||||
|
||||
def forward(self, x):
|
||||
u, v = x.chunk(2, dim=-1)
|
||||
v = self.norm(v)
|
||||
@ -205,7 +210,7 @@ class MlpMixer(nn.Module):
|
||||
in_chans=3,
|
||||
patch_size=16,
|
||||
num_blocks=8,
|
||||
hidden_dim=512,
|
||||
embed_dim=512,
|
||||
mlp_ratio=(0.5, 4.0),
|
||||
block_layer=MixerBlock,
|
||||
mlp_layer=Mlp,
|
||||
@ -218,59 +223,71 @@ class MlpMixer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.stem = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim,
|
||||
norm_layer=norm_layer if stem_norm else None)
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
|
||||
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
||||
self.blocks = nn.Sequential(*[
|
||||
block_layer(
|
||||
hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
|
||||
embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
|
||||
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
|
||||
for _ in range(num_blocks)])
|
||||
self.norm = norm_layer(hidden_dim)
|
||||
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Linear(embed_dim, self.num_classes) # zero init
|
||||
|
||||
self.init_weights(nlhb=nlhb)
|
||||
|
||||
def init_weights(self, nlhb=False):
|
||||
head_bias = -math.log(self.num_classes) if nlhb else 0.
|
||||
for n, m in self.named_modules():
|
||||
_init_weights(m, n, head_bias=head_bias)
|
||||
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
|
||||
|
||||
def forward(self, x):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = x.mean(dim=1)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _init_weights(m, n: str, head_bias: float = 0.):
|
||||
def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
|
||||
""" Mixer weight initialization (trying to match Flax defaults)
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
if n.startswith('head'):
|
||||
nn.init.zeros_(m.weight)
|
||||
nn.init.constant_(m.bias, head_bias)
|
||||
elif n.endswith('gate.proj'):
|
||||
nn.init.normal_(m.weight, std=1e-4)
|
||||
nn.init.ones_(m.bias)
|
||||
if isinstance(module, nn.Linear):
|
||||
if name.startswith('head'):
|
||||
nn.init.zeros_(module.weight)
|
||||
nn.init.constant_(module.bias, head_bias)
|
||||
else:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
if 'mlp' in n:
|
||||
nn.init.normal_(m.bias, std=1e-6)
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
if 'mlp' in name:
|
||||
nn.init.normal_(module.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
lecun_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
lecun_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif hasattr(module, 'init_weights'):
|
||||
# NOTE if a parent module contains init_weights method, it can override the init of the
|
||||
# child modules as this will be called in depth-first order.
|
||||
module.init_weights()
|
||||
|
||||
|
||||
def _create_mixer(variant, pretrained=False, **kwargs):
|
||||
@ -289,7 +306,7 @@ def mixer_s32_224(pretrained=False, **kwargs):
|
||||
""" Mixer-S/32 224x224
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs)
|
||||
model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
|
||||
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -299,7 +316,7 @@ def mixer_s16_224(pretrained=False, **kwargs):
|
||||
""" Mixer-S/16 224x224
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
|
||||
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -309,7 +326,7 @@ def mixer_b32_224(pretrained=False, **kwargs):
|
||||
""" Mixer-B/32 224x224
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -319,7 +336,7 @@ def mixer_b16_224(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -329,7 +346,7 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -339,7 +356,7 @@ def mixer_l32_224(pretrained=False, **kwargs):
|
||||
""" Mixer-L/32 224x224.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs)
|
||||
model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
|
||||
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -349,7 +366,7 @@ def mixer_l16_224(pretrained=False, **kwargs):
|
||||
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
|
||||
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -359,35 +376,38 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs):
|
||||
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
|
||||
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b16_224_miil(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
||||
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
||||
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gmixer_12_224(pretrained=False, **kwargs):
|
||||
""" Glu-Mixer-12 224x224 (short & fat)
|
||||
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0),
|
||||
patch_size=16, num_blocks=12, embed_dim=512, mlp_ratio=(1.0, 6.0),
|
||||
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
||||
model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -399,7 +419,7 @@ def gmixer_24_224(pretrained=False, **kwargs):
|
||||
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0),
|
||||
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 6.0),
|
||||
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
||||
model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -411,7 +431,7 @@ def resmlp_12_224(pretrained=False, **kwargs):
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
@ -422,7 +442,7 @@ def resmlp_24_224(pretrained=False, **kwargs):
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4,
|
||||
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -434,7 +454,7 @@ def resmlp_36_224(pretrained=False, **kwargs):
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4,
|
||||
patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -446,7 +466,7 @@ def gmlp_ti16_224(pretrained=False, **kwargs):
|
||||
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
mlp_layer=GatedMlp, **kwargs)
|
||||
model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -458,7 +478,7 @@ def gmlp_s16_224(pretrained=False, **kwargs):
|
||||
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
mlp_layer=GatedMlp, **kwargs)
|
||||
model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
@ -470,7 +490,7 @@ def gmlp_b16_224(pretrained=False, **kwargs):
|
||||
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
mlp_layer=GatedMlp, **kwargs)
|
||||
model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
@@ -119,6 +119,7 @@ class MobileNetV3(nn.Module):
        num_pooled_chs = head_chs * self.global_pool.feat_mult()
        self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
        self.act2 = act_layer(inplace=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        efficientnet_init_weights(self)
@@ -137,6 +138,7 @@ class MobileNetV3(nn.Module):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
@@ -151,8 +153,7 @@ class MobileNetV3(nn.Module):

    def forward(self, x):
        x = self.forward_features(x)
        if not self.global_pool.is_identity():
            x = x.flatten(1)
        x = self.flatten(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)
@@ -111,11 +111,11 @@ default_cfgs = dict(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
    eca_nfnet_l2=_dcfg(
        url='',
        pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0),
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth',
        pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
    eca_nfnet_l3=_dcfg(
        url='',
        pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
        pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0),

    nf_regnet_b0=_dcfg(
        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
@@ -210,6 +210,7 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
    return cfg



model_cfgs = dict(
    # NFNet-F models w/ GELU compatible with DeepMind weights
    dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
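With the weight URL filled in above, the eca_nfnet_l2 entry can be instantiated with pretrained weights (the commit message cites 84.7 top-1 @ 384x384); for example:

    import timm
    model = timm.create_model('eca_nfnet_l2', pretrained=True)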
@ -186,12 +186,13 @@ class PoolingVisionTransformer(nn.Module):
|
||||
]
|
||||
self.transformers = SequentialTuple(*transformers)
|
||||
self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
|
||||
self.embed_dim = base_dims[-1] * heads[-1]
|
||||
self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
|
||||
if num_classes > 0 and distilled else nn.Identity()
|
||||
self.head_dist = None
|
||||
if distilled:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
@ -207,13 +208,16 @@ class PoolingVisionTransformer(nn.Module):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
if self.head_dist is not None:
|
||||
return self.head, self.head_dist
|
||||
else:
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
|
||||
if num_classes > 0 and self.num_tokens == 2 else nn.Identity()
|
||||
if self.head_dist is not None:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
@ -221,19 +225,21 @@ class PoolingVisionTransformer(nn.Module):
|
||||
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x, cls_tokens = self.transformers((x, cls_tokens))
|
||||
cls_tokens = self.norm(cls_tokens)
|
||||
return cls_tokens
|
||||
if self.head_dist is not None:
|
||||
return cls_tokens[:, 0], cls_tokens[:, 1]
|
||||
else:
|
||||
return cls_tokens[:, 0]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x_cls = self.head(x[:, 0])
|
||||
if self.num_tokens > 1:
|
||||
x_dist = self.head_dist(x[:, 1])
|
||||
if self.head_dist is not None:
|
||||
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
return x_cls, x_dist
|
||||
return x, x_dist
|
||||
else:
|
||||
return (x_cls + x_dist) / 2
|
||||
return (x + x_dist) / 2
|
||||
else:
|
||||
return x_cls
|
||||
return self.head(x)
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
|
@@ -65,11 +65,18 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
        model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if module:
        models = list(_module_to_models[module])
        all_models = list(_module_to_models[module])
    else:
        models = _model_entrypoints.keys()
        all_models = _model_entrypoints.keys()
    if filter:
        models = fnmatch.filter(models, filter)  # include these models
        models = []
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        for f in include_filters:
            include_models = fnmatch.filter(all_models, f)  # include these models
            if len(include_models):
                models = set(models).union(include_models)
    else:
        models = all_models
    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
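list_models() now accepts a tuple/list of wildcard filters (the updated tests pass NON_STD_FILTERS directly); the result is the union of all matches, for example:

    import timm
    names = timm.list_models(['gmlp_*', 'resmlp_*'])  # union of models matching either pattern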
@@ -638,12 +638,15 @@ class ResNet(nn.Module):
        self.num_features = 512 * block.expansion
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

        self.init_weights(zero_init_last_bn=zero_init_last_bn)

    def init_weights(self, zero_init_last_bn=True):
        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
@ -35,9 +35,9 @@ import torch.nn as nn
|
||||
from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d


def _cfg(url='', **kwargs):
@@ -86,20 +86,10 @@ default_cfgs = {
        url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
        num_classes=21843),

    # trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
    # 'resnetv2_50x1_bits': _cfg(
    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
    # 'resnetv2_50x3_bits': _cfg(
    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
    # 'resnetv2_101x1_bits': _cfg(
    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
    # 'resnetv2_101x3_bits': _cfg(
    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
    # 'resnetv2_152x2_bits': _cfg(
    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
    # 'resnetv2_152x4_bits': _cfg(
    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
    'resnetv2_50': _cfg(
        input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'),
    'resnetv2_50d': _cfg(
        input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic', first_conv='stem.conv1'),
}


@@ -111,13 +101,6 @@ def make_div(v, divisor=8):
    return new_v


def tf2th(conv_weights):
    """Possibly convert HWIO to OIHW."""
    if conv_weights.ndim == 4:
        conv_weights = conv_weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(conv_weights)


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

@@ -152,6 +135,9 @@ class PreActBottleneck(nn.Module):
        self.conv3 = conv_layer(mid_chs, out_chs, 1)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def zero_init_last_bn(self):
        nn.init.zeros_(self.norm3.weight)

    def forward(self, x):
        x_preact = self.norm1(x)

@@ -198,6 +184,9 @@ class Bottleneck(nn.Module):
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act3 = act_layer(inplace=True)

    def zero_init_last_bn(self):
        nn.init.zeros_(self.norm3.weight)

    def forward(self, x):
        # shortcut branch
        shortcut = x
@@ -276,7 +265,7 @@ class ResNetStage(nn.Module):

def create_resnetv2_stem(
        in_chs, out_chs=64, stem_type='', preact=True,
        conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)):
        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
    stem = OrderedDict()
    assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')

@@ -285,14 +274,17 @@ def create_resnetv2_stem(
        # A 3 deep 3x3 conv stack as in ResNet V1D models
        mid_chs = out_chs // 2
        stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
        stem['norm1'] = norm_layer(mid_chs)
        stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
        stem['norm2'] = norm_layer(mid_chs)
        stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
        if not preact:
            stem['norm3'] = norm_layer(out_chs)
    else:
        # The usual 7x7 stem conv
        stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)

    if not preact:
        stem['norm'] = norm_layer(out_chs)
        if not preact:
            stem['norm'] = norm_layer(out_chs)

    if 'fixed' in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
@@ -312,11 +304,12 @@ class ResNetV2(nn.Module):
"""Implementation of Pre-activation (v2) ResNet mode.
    """

    def __init__(self, layers, channels=(256, 512, 1024, 2048),
                 num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
                 width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
                 act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8),
                 norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.):
    def __init__(
            self, layers, channels=(256, 512, 1024, 2048),
            num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
            width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
            act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
            drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
@@ -354,12 +347,14 @@ class ResNetV2(nn.Module):
        self.head = ClassifierHead(
            self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

        for n, m in self.named_modules():
            if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        self.init_weights(zero_init_last_bn=zero_init_last_bn)

    def init_weights(self, zero_init_last_bn=True):
        named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        _load_weights(self, checkpoint_path, prefix)

    def get_classifier(self):
        return self.head.fc
@@ -378,41 +373,59 @@ class ResNetV2(nn.Module):
    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        if not self.head.global_pool.is_identity():
            x = x.flatten(1)  # conv classifier, flatten if pooling isn't pass-through (disabled)
        return x

    def load_pretrained(self, checkpoint_path, prefix='resnet/'):
        import numpy as np
        weights = np.load(checkpoint_path)
        with torch.no_grad():
            stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
            if self.stem.conv.weight.shape[1] == 1:
                self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
                # FIXME handle > 3 in_chans?
            else:
                self.stem.conv.weight.copy_(stem_conv_w)
            self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
            self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
            if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
                self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
                self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
            for i, (sname, stage) in enumerate(self.stages.named_children()):
                for j, (bname, block) in enumerate(stage.blocks.named_children()):
                    convname = 'standardized_conv2d'
                    block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
                    block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
                    block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
                    block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
                    block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
                    block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
                    block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
                    block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
                    block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
                    block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
                    if block.downsample is not None:
                        w = weights[f'{block_prefix}a/proj/{convname}/kernel']
                        block.downsample.conv.weight.copy_(tf2th(w))

def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
    if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
        module.zero_init_last_bn()


@torch.no_grad()
def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
    import numpy as np

    def t2p(conv_weights):
        """Possibly convert HWIO to OIHW."""
        if conv_weights.ndim == 4:
            conv_weights = conv_weights.transpose([3, 2, 0, 1])
        return torch.from_numpy(conv_weights)

    weights = np.load(checkpoint_path)
    stem_conv_w = adapt_input_conv(
        model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
    model.stem.conv.weight.copy_(stem_conv_w)
    model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
    model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
    if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
        model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
        model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
    for i, (sname, stage) in enumerate(model.stages.named_children()):
        for j, (bname, block) in enumerate(stage.blocks.named_children()):
            cname = 'standardized_conv2d'
            block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
            block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
            block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
            block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
            block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
            block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
            block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
            block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
            block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
            block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
            if block.downsample is not None:
                w = weights[f'{block_prefix}a/proj/{cname}/kernel']
                block.downsample.conv.weight.copy_(t2p(w))
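The module-level `_load_weights` above keeps the original Big Transfer (BiT) checkpoint layout, so a ResNetV2 model can still be populated directly from an official `.npz` file via the class's `load_pretrained` method. A minimal usage sketch; the local checkpoint filename is an assumption for illustration, not part of this diff:

```python
import torch
from timm import create_model

# build the 21k-class BiT-M architecture without timm's converted weights
model = create_model('resnetv2_50x1_bitm_in21k', pretrained=False)

# populate it directly from an official Big Transfer .npz checkpoint
# (filename assumed; conv kernels are HWIO under the 'resnet/' prefix)
model.load_pretrained('BiT-M-R50x1.npz')

model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 21843])
```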


def _create_resnetv2(variant, pretrained=False, **kwargs):
@@ -425,130 +438,99 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
        **kwargs)


def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
    return _create_resnetv2(
        variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)


@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x1_bitm', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
    return _create_resnetv2_bit(
        'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50x3_bitm', pretrained=pretrained,
        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
    return _create_resnetv2_bit(
        'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x1_bitm', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
    return _create_resnetv2_bit(
        'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x3_bitm', pretrained=pretrained,
        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
    return _create_resnetv2_bit(
        'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x2_bitm', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
    return _create_resnetv2_bit(
        'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)


@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_152x4_bitm', pretrained=pretrained,
        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
    return _create_resnetv2_bit(
        'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)


@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
    return _create_resnetv2_bit(
        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
        layers=[3, 4, 6, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
    return _create_resnetv2_bit(
        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
        layers=[3, 4, 6, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
        layers=[3, 4, 23, 3], width_factor=1, **kwargs)


@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
    return _create_resnetv2_bit(
        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
        layers=[3, 4, 23, 3], width_factor=3, **kwargs)


@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
    return _create_resnetv2_bit(
        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
        layers=[3, 8, 36, 3], width_factor=2, **kwargs)


@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
    return _create_resnetv2(
    return _create_resnetv2_bit(
        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
        layers=[3, 8, 36, 3], width_factor=4, **kwargs)


# NOTE the 'S' versions of the model weights aren't as interesting as original 21k or transfer to 1K M.
@register_model
def resnetv2_50(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50', pretrained=pretrained,
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)

# @register_model
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
#     return _create_resnetv2(
#         'resnetv2_50x1_bits', pretrained=pretrained,
#         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
#     return _create_resnetv2(
#         'resnetv2_50x3_bits', pretrained=pretrained,
#         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
#     return _create_resnetv2(
#         'resnetv2_101x1_bits', pretrained=pretrained,
#         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
#     return _create_resnetv2(
#         'resnetv2_101x3_bits', pretrained=pretrained,
#         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
#     return _create_resnetv2(
#         'resnetv2_152x2_bits', pretrained=pretrained,
#         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
#     return _create_resnetv2(
#         'resnetv2_152x4_bits', pretrained=pretrained,
#         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
#

@register_model
def resnetv2_50d(pretrained=False, **kwargs):
    return _create_resnetv2(
        'resnetv2_50d', pretrained=pretrained,
        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
        stem_type='deep', avg_down=True, **kwargs)

@@ -126,19 +126,18 @@ class WindowAttention(nn.Module):
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
@@ -210,7 +209,6 @@ class SwinTransformerBlock(nn.Module):
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
@@ -219,7 +217,7 @@ class SwinTransformerBlock(nn.Module):
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
@@ -236,8 +234,8 @@ class SwinTransformerBlock(nn.Module):

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
@@ -369,7 +367,6 @@ class BasicLayer(nn.Module):
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
@@ -379,7 +376,7 @@ class BasicLayer(nn.Module):
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
@@ -390,14 +387,11 @@ class BasicLayer(nn.Module):

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
@@ -436,7 +430,6 @@ class SwinTransformer(nn.Module):
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
@@ -448,7 +441,7 @@ class SwinTransformer(nn.Module):

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, weight_init='', **kwargs):
@@ -491,8 +484,9 @@ class SwinTransformer(nn.Module):
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
@@ -520,6 +514,13 @@ class SwinTransformer(nn.Module):
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.absolute_pos_embed is not None:
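The `get_classifier` / `reset_classifier` additions above bring SwinTransformer in line with the classifier handling the rest of the library already exposes. A small sketch of the resulting surface; the `swin_tiny_patch4_window7_224` entrypoint name is used purely for illustration:

```python
import torch
from timm import create_model

model = create_model('swin_tiny_patch4_window7_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

print(model.get_classifier())   # the final nn.Linear head
model.reset_classifier(0)       # detach the head; it becomes nn.Identity()
out = model(x)
print(out.shape[-1] == model.num_features)   # forward now yields num_features-wide features
```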

@@ -278,6 +278,8 @@ class Twins(nn.Module):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.embed_dims = embed_dims
        self.num_features = embed_dims[-1]

        img_size = to_2tuple(img_size)
        prev_chs = in_chans
@@ -303,10 +305,10 @@ class Twins(nn.Module):

        self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])

        self.norm = norm_layer(embed_dims[-1])
        self.norm = norm_layer(self.num_features)

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # init weights
        self.apply(self._init_weights)
@@ -320,7 +322,7 @@ class Twins(nn.Module):

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):

@@ -13,7 +13,7 @@ import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
from .registry import register_model


@@ -140,14 +140,14 @@ class Visformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
                 depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
                 vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
                 vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.embed_dim = embed_dim
        self.init_channels = init_channels
        self.img_size = img_size
        self.vit_stem = vit_stem
        self.pool = pool
        self.conv_init = conv_init
        if isinstance(depth, (list, tuple)):
            self.stage_num1, self.stage_num2, self.stage_num3 = depth
@@ -164,31 +164,31 @@ class Visformer(nn.Module):
            self.patch_embed1 = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
            img_size //= 16
            img_size = [x // 16 for x in img_size]
        else:
            if self.init_channels is None:
                self.stem = None
                self.patch_embed1 = PatchEmbed(
                    img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
                    embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
                img_size //= 8
                img_size = [x // 8 for x in img_size]
            else:
                self.stem = nn.Sequential(
                    nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
                    nn.BatchNorm2d(self.init_channels),
                    nn.ReLU(inplace=True)
                )
                img_size //= 2
                img_size = [x // 2 for x in img_size]
                self.patch_embed1 = PatchEmbed(
                    img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
                    embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
                img_size //= 4
                img_size = [x // 4 for x in img_size]

        if self.pos_embed:
            if self.vit_stem:
                self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
                self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
            else:
                self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
                self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
            self.pos_drop = nn.Dropout(p=drop_rate)
        self.stage1 = nn.ModuleList([
            Block(
@@ -199,14 +199,14 @@ class Visformer(nn.Module):
            for i in range(self.stage_num1)
        ])

        #stage2
        # stage2
        if not self.vit_stem:
            self.patch_embed2 = PatchEmbed(
                img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
                embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
            img_size //= 2
            img_size = [x // 2 for x in img_size]
            if self.pos_embed:
                self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
                self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
            self.stage2 = nn.ModuleList([
                Block(
                    dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@@ -221,9 +221,9 @@ class Visformer(nn.Module):
            self.patch_embed3 = PatchEmbed(
                img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
                embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
            img_size //= 2
            img_size = [x // 2 for x in img_size]
            if self.pos_embed:
                self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
                self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
            self.stage3 = nn.ModuleList([
                Block(
                    dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@@ -234,11 +234,10 @@ class Visformer(nn.Module):
            ])

        # head
        if self.pool:
            self.global_pooling = nn.AdaptiveAvgPool2d(1)
        head_dim = embed_dim if self.vit_stem else embed_dim * 2
        self.norm = norm_layer(head_dim)
        self.head = nn.Linear(head_dim, num_classes)
        self.num_features = embed_dim if self.vit_stem else embed_dim * 2
        self.norm = norm_layer(self.num_features)
        self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
        self.head = nn.Linear(self.num_features, num_classes)

        # weights init
        if self.pos_embed:
@@ -267,7 +266,14 @@ class Visformer(nn.Module):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)

    def forward(self, x):
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        if self.stem is not None:
            x = self.stem(x)

@@ -297,14 +303,13 @@ class Visformer(nn.Module):
        for b in self.stage3:
            x = b(x)

        # head
        x = self.norm(x)
        if self.pool:
            x = self.global_pooling(x)
        else:
            x = x[:, :, 0, 0]
        return x

        x = self.head(x.view(x.size(0), -1))
    def forward(self, x):
        x = self.forward_features(x)
        x = self.global_pool(x)
        x = self.head(x)
        return x
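With the head split into an explicit `global_pool` and `head` module via `create_classifier`, Visformer now follows the same forward_features / classifier protocol as the other models touched in this commit. A small illustrative sketch using the `visformer_tiny` entrypoint defined below:

```python
import torch
from timm import create_model

model = create_model('visformer_tiny', pretrained=False)
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)   # unpooled NCHW feature map
print(feats.shape)

model.reset_classifier(0)           # head becomes a no-op; forward() returns pooled features
print(model(x).shape)
```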


@@ -321,7 +326,7 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
@register_model
def visformer_tiny(pretrained=False, **kwargs):
    model_cfg = dict(
        img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
        init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
        attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
        embed_norm=nn.BatchNorm2d, **kwargs)
    model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
@@ -331,7 +336,7 @@ def visformer_tiny(pretrained=False, **kwargs):
@register_model
def visformer_small(pretrained=False, **kwargs):
    model_cfg = dict(
        img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
        init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
        attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
        embed_norm=nn.BatchNorm2d, **kwargs)
    model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)

@@ -28,7 +28,7 @@ import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model

@@ -47,9 +47,18 @@ def _cfg(url='', **kwargs):


default_cfgs = {
    # patch models (my experiments)
    # FIXME weights coming
    'vit_tiny_patch16_224': _cfg(
        url='',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'vit_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
        url='',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'vit_small_patch32_224': _cfg(
        url='',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),

    # patch models (weights ported from official Google JAX impl)
@@ -97,29 +106,29 @@ default_cfgs = {
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

    # deit models (FB weights)
    'vit_deit_tiny_patch16_224': _cfg(
    'deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'vit_deit_small_patch16_224': _cfg(
    'deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'vit_deit_base_patch16_224': _cfg(
    'deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
    'vit_deit_base_patch16_384': _cfg(
    'deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_deit_tiny_distilled_patch16_224': _cfg(
    'deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_small_distilled_patch16_224': _cfg(
    'deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_base_distilled_patch16_224': _cfg(
    'deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        classifier=('head', 'head_dist')),
    'vit_deit_base_distilled_patch16_384': _cfg(
    'deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),

    # ViT ImageNet-21K-P pretraining
    # ViT ImageNet-21K-P pretraining by MILL
    'vit_base_patch16_224_miil_in21k': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
@@ -133,11 +142,11 @@ default_cfgs = {


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
@@ -161,12 +170,11 @@ class Attention(nn.Module):

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
@@ -190,7 +198,7 @@ class VisionTransformer(nn.Module):
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
@@ -204,7 +212,6 @@ class VisionTransformer(nn.Module):
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
@@ -233,8 +240,8 @@ class VisionTransformer(nn.Module):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

@@ -254,16 +261,17 @@ class VisionTransformer(nn.Module):
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if weight_init.startswith('jax'):
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            for n, m in self.named_modules():
                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)
@@ -272,6 +280,10 @@ class VisionTransformer(nn.Module):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}
@@ -317,39 +329,92 @@ class VisionTransformer(nn.Module):
        return x


def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(m, nn.Linear):
        if n.startswith('head'):
            nn.init.zeros_(m.weight)
            nn.init.constant_(m.bias, head_bias)
        elif n.startswith('pre_logits'):
            lecun_normal_(m.weight)
            nn.init.zeros_(m.bias)
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    if 'mlp' in n:
                        nn.init.normal_(m.bias, std=1e-6)
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(m.bias)
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    elif jax_impl and isinstance(m, nn.Conv2d):
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
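`init_weights` now routes through `named_apply` so that `_init_vit_weights` receives each submodule together with its dotted name, which is what the `head` / `pre_logits` / `mlp` prefix checks above key on. A rough, hypothetical stand-in for that name-aware traversal, shown only to illustrate the idea rather than timm's exact helper:

```python
from functools import partial
import torch.nn as nn

def apply_named(fn, module: nn.Module, name: str = ''):
    # hypothetical stand-in: visit children depth-first, then call fn with the
    # module and its dotted name so init code can branch on name prefixes
    for child_name, child in module.named_children():
        apply_named(fn, child, f'{name}.{child_name}' if name else child_name)
    fn(module=module, name=name)

def demo_init(module, name='', head_bias=0.):
    # toy init rule: zero the head weight and set its bias, like the ViT init above
    if isinstance(module, nn.Linear) and name.startswith('head'):
        nn.init.zeros_(module.weight)
        nn.init.constant_(module.bias, head_bias)

net = nn.Sequential()
net.add_module('head', nn.Linear(8, 2))
apply_named(partial(demo_init, head_bias=-6.9), net)
print(net.head.bias)
```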


@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        if t and w.ndim == 4:
            w = w.transpose([3, 2, 0, 1])
        elif t and w.ndim == 3:
            w = w.transpose([2, 0, 1])
        elif t and w.ndim == 2:
            w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix:
        prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix

    input_conv_w = adapt_input_conv(
        model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(input_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False))
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T,
            _n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T,
            _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1),
            _n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1),
            _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel']))
        block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias']))
        block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel']))
        block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
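This is the direct .npz loading path for the pure (non-hybrid) VisionTransformer: an official JAX/Flax checkpoint can be mapped onto the PyTorch model without a separate conversion script. A usage sketch; the local `.npz` filename is an assumption, and the head copy is simply skipped by the shape check above if the class count does not match:

```python
import torch
from timm import create_model

# build the architecture only; weights come from the npz file below
model = create_model('vit_base_patch16_224', pretrained=False)

# assumed local file: an official ViT-B/16 checkpoint in Flax .npz format;
# the 'opt/target/' key prefix is detected automatically when present
model.load_pretrained('ViT-B_16.npz')

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)
```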


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
@@ -418,22 +483,33 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw


@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
    NOTE:
        * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
        * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
def vit_tiny_patch16_224(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16)
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
        qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
    if pretrained:
        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
        model_kwargs.setdefault('qk_scale', 768 ** -0.5)
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32)
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
@@ -569,86 +645,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):


@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
def deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
def deit_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
def deit_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
        'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
        'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
        'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
    return model
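Because the `vit_deit_*` entrypoints are renamed to `deit_*` here, downstream code referring to the old names needs updating; the distilled variants also advertise two classifiers (`head`, `head_dist`) in their default_cfg. A short usage sketch reflecting the new names:

```python
import torch
from timm import create_model, list_models

print(list_models('deit_*'))   # the renamed DeiT entrypoints registered above

model = create_model('deit_base_distilled_patch16_224', pretrained=False)
print(model.default_cfg['classifier'])   # ('head', 'head_dist') for distilled models

model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)
```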


@@ -46,8 +46,8 @@ default_cfgs = {
        input_size=(3, 384, 384), crop_pct=1.0),

    # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
    'vit_tiny_r_s16_p8_224': _cfg(),
    'vit_small_r_s16_p8_224': _cfg(),
    'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
    'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
    'vit_small_r20_s16_p2_224': _cfg(),
    'vit_small_r20_s16_224': _cfg(),
    'vit_small_r26_s32_224': _cfg(),
@@ -57,10 +57,14 @@ default_cfgs = {
    'vit_large_r50_s32_224': _cfg(),

    # hybrid models (using timm resnet backbones)
    'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'vit_small_resnet26d_224': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_small_resnet50d_s16_224': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_base_resnet26d_224': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
    'vit_base_resnet50d_224': _cfg(
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
}


@@ -140,12 +144,6 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
    return model


@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    # NOTE this is forwarding to model def above for backwards compatibility
    return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)


@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
@@ -158,12 +156,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
    return model


@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
    # NOTE this is forwarding to model def above for backwards compatibility
    return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)


@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.