Merge pull request #702 from rwightman/cleanup_xla_model_fixes
AugReg Vision Transformers, XLA model compat for ResNetV2-BiT / NFNet, ECA-NFNet-L2, GMixer-24 weights, ResMLP official weights, and cleanuppull/714/head
commit
79927baaec
19
README.md
19
README.md
|
@ -23,6 +23,25 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
|
|||
|
||||
## What's New
|
||||
|
||||
### June 20, 2021
|
||||
* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
|
||||
* .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
|
||||
* See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from official impl for navigating the augreg weights
|
||||
* Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
|
||||
* Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
|
||||
* `vit_deit_*` renamed to just `deit_*`
|
||||
* Remove my old small model, replace with DeiT compatible small w/ AugReg weights
|
||||
* Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params.
|
||||
* Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
|
||||
* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
|
||||
* Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
|
||||
* NFNets and ResNetV2-BiT models work w/ Pytorch XLA now
|
||||
* weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
|
||||
* eps values adjusted, will be slight differences but should be quite close
|
||||
* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
|
||||
* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
|
||||
* Please report any regressions, this PR touched quite a few models.
|
||||
|
||||
### June 8, 2021
|
||||
* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
|
||||
* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
|
||||
|
|
|
@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
|
|||
# transformer models don't support many of the spatial / feature based model functionalities
|
||||
NON_STD_FILTERS = [
|
||||
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'convit_*', 'levit*', 'visformer*']
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*']
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
|
||||
# exclude models that cause specific test failures
|
||||
|
@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size):
|
|||
state_dict = model.state_dict()
|
||||
cfg = model.default_cfg
|
||||
|
||||
classifier = cfg['classifier']
|
||||
pool_size = cfg['pool_size']
|
||||
input_size = model.default_cfg['input_size']
|
||||
|
||||
|
@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size):
|
|||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
|
||||
# check classifier name matches default_cfg
|
||||
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
|
||||
classifier = cfg['classifier']
|
||||
if not isinstance(classifier, (tuple, list)):
|
||||
classifier = classifier,
|
||||
for c in classifier:
|
||||
assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
|
||||
|
||||
# check first conv(s) names match default_cfg
|
||||
first_conv = cfg['first_conv']
|
||||
if isinstance(first_conv, str):
|
||||
first_conv = (first_conv,)
|
||||
assert isinstance(first_conv, (tuple, list))
|
||||
for fc in first_conv:
|
||||
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
|
||||
|
||||
|
||||
@pytest.mark.timeout(300)
|
||||
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_default_cfgs_non_std(model_name, batch_size):
|
||||
"""Run a single forward pass with each model"""
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
state_dict = model.state_dict()
|
||||
cfg = model.default_cfg
|
||||
|
||||
input_size = _get_input_size(model=model)
|
||||
if max(input_size) > 320: # FIXME const
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
input_tensor = torch.randn((batch_size, *input_size))
|
||||
|
||||
# test forward_features (always unpooled)
|
||||
outputs = model.forward_features(input_tensor)
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = outputs[0]
|
||||
assert outputs.shape[1] == model.num_features
|
||||
|
||||
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
||||
model.reset_classifier(0)
|
||||
outputs = model.forward(input_tensor)
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = outputs[0]
|
||||
assert len(outputs.shape) == 2
|
||||
assert outputs.shape[1] == model.num_features
|
||||
|
||||
# check classifier name matches default_cfg
|
||||
classifier = cfg['classifier']
|
||||
if not isinstance(classifier, (tuple, list)):
|
||||
classifier = classifier,
|
||||
for c in classifier:
|
||||
assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
|
||||
|
||||
# check first conv(s) names match default_cfg
|
||||
first_conv = cfg['first_conv']
|
||||
|
|
|
@ -74,11 +74,11 @@ default_cfgs = dict(
|
|||
class ClassAttn(nn.Module):
|
||||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to do CA
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.k = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
|
@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
|
|||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add CA and LayerScale
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
|
||||
mlp_block=Mlp, init_values=1e-4):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_block(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
@ -134,14 +134,14 @@ class LayerScaleBlockClassAttn(nn.Module):
|
|||
class TalkingHeadAttn(nn.Module):
|
||||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
|
@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
|
|||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add layerScale
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
|
||||
mlp_block=Mlp, init_values=1e-4):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_block(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
@ -202,7 +202,7 @@ class Cait(nn.Module):
|
|||
# with slight modifications to adapt to our cait models
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
global_pool=None,
|
||||
|
@ -235,14 +235,14 @@ class Cait(nn.Module):
|
|||
dpr = [drop_path_rate for i in range(depth)]
|
||||
self.blocks = nn.ModuleList([
|
||||
block_layers(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
|
||||
for i in range(depth)])
|
||||
|
||||
self.blocks_token_only = nn.ModuleList([
|
||||
block_layers_token(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
|
||||
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
|
||||
act_layer=act_layer, attn_block=attn_block_token_only,
|
||||
mlp_block=mlp_block_token_only, init_values=init_scale)
|
||||
|
@ -270,6 +270,13 @@ class Cait(nn.Module):
|
|||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
|
@ -293,7 +300,6 @@ class Cait(nn.Module):
|
|||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -335,6 +335,8 @@ class CoaT(nn.Module):
|
|||
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
|
||||
self.return_interm_layers = return_interm_layers
|
||||
self.out_features = out_features
|
||||
self.embed_dims = embed_dims
|
||||
self.num_features = embed_dims[-1]
|
||||
self.num_classes = num_classes
|
||||
|
||||
# Patch embeddings.
|
||||
|
@ -441,10 +443,10 @@ class CoaT(nn.Module):
|
|||
# CoaT series: Aggregate features of last three scales for classification.
|
||||
assert embed_dims[1] == embed_dims[2] == embed_dims[3]
|
||||
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
|
||||
self.head = nn.Linear(embed_dims[3], num_classes)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
# CoaT-Lite series: Use feature of last scale for classification.
|
||||
self.head = nn.Linear(embed_dims[3], num_classes)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
# Initialize weights.
|
||||
trunc_normal_(self.cls_token1, std=.02)
|
||||
|
@ -471,7 +473,7 @@ class CoaT(nn.Module):
|
|||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def insert_cls(self, x, cls_token):
|
||||
""" Insert CLS token. """
|
||||
|
|
|
@ -57,13 +57,13 @@ default_cfgs = {
|
|||
|
||||
|
||||
class GPSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
|
||||
locality_strength=1.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.dim = dim
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
self.locality_strength = locality_strength
|
||||
|
||||
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
||||
|
@ -141,11 +141,11 @@ class GPSA(nn.Module):
|
|||
|
||||
|
||||
class MHSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
|
@ -190,19 +190,16 @@ class MHSA(nn.Module):
|
|||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.use_gpsa = use_gpsa
|
||||
if self.use_gpsa:
|
||||
self.attn = GPSA(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
||||
proj_drop=drop, **kwargs)
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
|
||||
else:
|
||||
self.attn = MHSA(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
|
||||
proj_drop=drop, **kwargs)
|
||||
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
@ -219,7 +216,7 @@ class ConViT(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
|
||||
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
|
||||
super().__init__()
|
||||
|
@ -249,13 +246,13 @@ class ConViT(nn.Module):
|
|||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_gpsa=True,
|
||||
locality_strength=locality_strength)
|
||||
if i < local_up_to_layer else
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_gpsa=False)
|
||||
for i in range(depth)])
|
||||
|
|
|
@ -288,6 +288,8 @@ class DLA(nn.Module):
|
|||
self.num_features = channels[-1]
|
||||
self.global_pool, self.fc = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
|
@ -314,6 +316,7 @@ class DLA(nn.Module):
|
|||
self.num_classes = num_classes
|
||||
self.global_pool, self.fc = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.base_layer(x)
|
||||
|
@ -331,8 +334,7 @@ class DLA(nn.Module):
|
|||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.fc(x)
|
||||
if not self.global_pool.is_identity():
|
||||
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
|
||||
x = self.flatten(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -237,6 +237,7 @@ class DPN(nn.Module):
|
|||
# Using 1x1 conv for the FC layer to allow the extra pooling scheme
|
||||
self.global_pool, self.classifier = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
|
||||
def get_classifier(self):
|
||||
return self.classifier
|
||||
|
@ -245,6 +246,7 @@ class DPN(nn.Module):
|
|||
self.num_classes = num_classes
|
||||
self.global_pool, self.classifier = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
return self.features(x)
|
||||
|
@ -255,8 +257,7 @@ class DPN(nn.Module):
|
|||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.classifier(x)
|
||||
if not self.global_pool.is_identity():
|
||||
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
|
||||
x = self.flatten(x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -133,7 +133,7 @@ class GhostBottleneck(nn.Module):
|
|||
|
||||
|
||||
class GhostNet(nn.Module):
|
||||
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32):
|
||||
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
|
||||
super(GhostNet, self).__init__()
|
||||
# setting of inverted residual blocks
|
||||
assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
|
||||
|
@ -178,9 +178,10 @@ class GhostNet(nn.Module):
|
|||
|
||||
# building last several layers
|
||||
self.num_features = out_chs = 1280
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type='avg')
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
|
||||
self.act2 = nn.ReLU(inplace=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(out_chs, num_classes)
|
||||
|
||||
def get_classifier(self):
|
||||
|
@ -190,6 +191,7 @@ class GhostNet(nn.Module):
|
|||
self.num_classes = num_classes
|
||||
# cannot meaningfully change pooling of efficient head after creation
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -204,8 +206,7 @@ class GhostNet(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if not self.global_pool.is_identity():
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.flatten(x)
|
||||
if self.dropout > 0.:
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = self.classifier(x)
|
||||
|
|
|
@ -45,6 +45,13 @@ def load_state_dict(checkpoint_path, use_ema=False):
|
|||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, 'load_pretrained'):
|
||||
model.load_pretrained(checkpoint_path)
|
||||
else:
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
@ -477,3 +484,25 @@ def model_parameters(model, exclude_head=False):
|
|||
return [p for p in model.parameters()][:-2]
|
||||
else:
|
||||
return model.parameters()
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
return module
|
||||
|
||||
|
||||
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
if not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
yield from named_modules(
|
||||
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
yield name, module
|
||||
|
|
|
@ -55,7 +55,7 @@ class FastAdaptiveAvgPool2d(nn.Module):
|
|||
self.flatten = flatten
|
||||
|
||||
def forward(self, x):
|
||||
return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
|
||||
return x.mean((2, 3), keepdim=not self.flatten)
|
||||
|
||||
|
||||
class AdaptiveAvgMaxPool2d(nn.Module):
|
||||
|
@ -82,13 +82,13 @@ class SelectAdaptivePool2d(nn.Module):
|
|||
def __init__(self, output_size=1, pool_type='fast', flatten=False):
|
||||
super(SelectAdaptivePool2d, self).__init__()
|
||||
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
|
||||
self.flatten = flatten
|
||||
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
|
||||
if pool_type == '':
|
||||
self.pool = nn.Identity() # pass through
|
||||
elif pool_type == 'fast':
|
||||
assert output_size == 1
|
||||
self.pool = FastAdaptiveAvgPool2d(self.flatten)
|
||||
self.flatten = False
|
||||
self.pool = FastAdaptiveAvgPool2d(flatten)
|
||||
self.flatten = nn.Identity()
|
||||
elif pool_type == 'avg':
|
||||
self.pool = nn.AdaptiveAvgPool2d(output_size)
|
||||
elif pool_type == 'avgmax':
|
||||
|
@ -101,12 +101,11 @@ class SelectAdaptivePool2d(nn.Module):
|
|||
assert False, 'Invalid pool type: %s' % pool_type
|
||||
|
||||
def is_identity(self):
|
||||
return self.pool_type == ''
|
||||
return not self.pool_type
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(1)
|
||||
x = self.flatten(x)
|
||||
return x
|
||||
|
||||
def feat_mult(self):
|
||||
|
|
|
@ -20,7 +20,7 @@ def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
|
|||
return global_pool, num_pooled_features
|
||||
|
||||
|
||||
def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
def _create_fc(num_features, num_classes, use_conv=False):
|
||||
if num_classes <= 0:
|
||||
fc = nn.Identity() # pass-through (no classifier)
|
||||
elif use_conv:
|
||||
|
@ -45,11 +45,12 @@ class ClassifierHead(nn.Module):
|
|||
self.drop_rate = drop_rate
|
||||
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
|
||||
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
||||
self.flatten_after_fc = use_conv and pool_type
|
||||
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.global_pool(x)
|
||||
if self.drop_rate:
|
||||
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
||||
x = self.fc(x)
|
||||
x = self.flatten(x)
|
||||
return x
|
||||
|
|
|
@ -40,6 +40,12 @@ class GluMlp(nn.Module):
|
|||
self.fc2 = nn.Linear(hidden_features // 2, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def init_weights(self):
|
||||
# override init of fc1 w/ gate portion set to weight near zero, bias=1
|
||||
fc1_mid = self.fc1.bias.shape[0] // 2
|
||||
nn.init.ones_(self.fc1.bias[fc1_mid:])
|
||||
nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x, gates = x.chunk(2, dim=-1)
|
||||
|
|
|
@ -27,7 +27,8 @@ class AvgPool2dSame(nn.AvgPool2d):
|
|||
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||
|
||||
def forward(self, x):
|
||||
return avg_pool2d_same(
|
||||
x = pad_same(x, self.kernel_size, self.stride)
|
||||
return F.avg_pool2d(
|
||||
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
|
||||
|
||||
|
||||
|
@ -41,14 +42,15 @@ def max_pool2d_same(
|
|||
class MaxPool2dSame(nn.MaxPool2d):
|
||||
""" Tensorflow like 'SAME' wrapper for 2D max pooling
|
||||
"""
|
||||
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
|
||||
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
|
||||
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
|
||||
|
||||
def forward(self, x):
|
||||
return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
|
||||
x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
|
||||
return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
|
||||
|
||||
|
||||
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
""" Convolution with Weight Standardization (StdConv and ScaledStdConv)
|
||||
|
||||
StdConv:
|
||||
@article{weightstandardization,
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
|
||||
title = {Weight Standardization},
|
||||
journal = {arXiv preprint arXiv:1903.10520},
|
||||
year = {2019},
|
||||
}
|
||||
Code: https://github.com/joe-siyuan-qiao/WeightStandardization
|
||||
|
||||
ScaledStdConv:
|
||||
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
|
||||
- https://arxiv.org/abs/2101.08692
|
||||
Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
|
||||
|
||||
Hacked together by / copyright Ross Wightman, 2021.
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
@ -5,12 +23,6 @@ import torch.nn.functional as F
|
|||
from .padding import get_padding, get_padding_value, pad_same
|
||||
|
||||
|
||||
def get_weight(module):
|
||||
std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = (module.weight - mean) / (std + module.eps)
|
||||
return weight
|
||||
|
||||
|
||||
class StdConv2d(nn.Conv2d):
|
||||
"""Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
|
||||
|
||||
|
@ -18,8 +30,8 @@ class StdConv2d(nn.Conv2d):
|
|||
https://arxiv.org/abs/1903.10520v2
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
|
||||
groups=1, bias=False, eps=1e-5):
|
||||
self, in_channel, out_channels, kernel_size, stride=1, padding=None,
|
||||
dilation=1, groups=1, bias=False, eps=1e-6):
|
||||
if padding is None:
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
|
@ -27,13 +39,11 @@ class StdConv2d(nn.Conv2d):
|
|||
padding=padding, dilation=dilation, groups=groups, bias=bias)
|
||||
self.eps = eps
|
||||
|
||||
def get_weight(self):
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = (self.weight - mean) / (std + self.eps)
|
||||
return weight
|
||||
|
||||
def forward(self, x):
|
||||
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -44,8 +54,8 @@ class StdConv2dSame(nn.Conv2d):
|
|||
https://arxiv.org/abs/1903.10520v2
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
|
||||
groups=1, bias=False, eps=1e-5):
|
||||
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
|
||||
dilation=1, groups=1, bias=False, eps=1e-6):
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
||||
super().__init__(
|
||||
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
|
@ -53,15 +63,13 @@ class StdConv2dSame(nn.Conv2d):
|
|||
self.same_pad = is_dynamic
|
||||
self.eps = eps
|
||||
|
||||
def get_weight(self):
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = (self.weight - mean) / (std + self.eps)
|
||||
return weight
|
||||
|
||||
def forward(self, x):
|
||||
if self.same_pad:
|
||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -75,8 +83,8 @@ class ScaledStdConv2d(nn.Conv2d):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
|
||||
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
|
||||
self, in_channels, out_channels, kernel_size, stride=1, padding=None,
|
||||
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
|
||||
if padding is None:
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
|
@ -84,19 +92,14 @@ class ScaledStdConv2d(nn.Conv2d):
|
|||
groups=groups, bias=bias)
|
||||
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
|
||||
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
|
||||
self.eps = eps ** 2 if use_layernorm else eps
|
||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
|
||||
|
||||
def get_weight(self):
|
||||
if self.use_layernorm:
|
||||
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
||||
else:
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = self.scale * (self.weight - mean) / (std + self.eps)
|
||||
return self.gain * weight
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
weight=(self.gain * self.scale).view(-1),
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class ScaledStdConv2dSame(nn.Conv2d):
|
||||
|
@ -109,8 +112,8 @@ class ScaledStdConv2dSame(nn.Conv2d):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
|
||||
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
|
||||
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
|
||||
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
|
||||
super().__init__(
|
||||
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
|
||||
|
@ -118,26 +121,13 @@ class ScaledStdConv2dSame(nn.Conv2d):
|
|||
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
|
||||
self.scale = gamma * self.weight[0].numel() ** -0.5
|
||||
self.same_pad = is_dynamic
|
||||
self.eps = eps ** 2 if use_layernorm else eps
|
||||
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
|
||||
|
||||
# NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
|
||||
# to make much numerical difference (+/- .002 to .004) in top-1 during eval.
|
||||
# def get_weight(self):
|
||||
# var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
# scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
|
||||
# weight = (self.weight - mean) * scale
|
||||
# return self.gain * weight
|
||||
|
||||
def get_weight(self):
|
||||
if self.use_layernorm:
|
||||
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
|
||||
else:
|
||||
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
weight = self.scale * (self.weight - mean) / (std + self.eps)
|
||||
return self.gain * weight
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
if self.same_pad:
|
||||
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
||||
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
weight = F.batch_norm(
|
||||
self.weight.view(1, self.out_channels, -1), None, None,
|
||||
weight=(self.gain * self.scale).view(-1),
|
||||
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
|
|
@ -84,63 +84,33 @@ __all__ = ['Levit']
|
|||
|
||||
|
||||
@register_model
|
||||
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
|
||||
def levit_128s(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_128(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_192(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_256(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
|
||||
def levit_384(pretrained=False, use_conv=False, **kwargs):
|
||||
return create_levit(
|
||||
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
|
||||
return create_levit(
|
||||
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
|
||||
return create_levit(
|
||||
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
|
||||
'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
|
||||
|
||||
|
||||
class ConvNorm(nn.Sequential):
|
||||
|
@ -427,6 +397,9 @@ class AttentionSubsample(nn.Module):
|
|||
|
||||
class Levit(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
|
||||
NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
|
||||
w/ train scripts that don't take tuple outputs,
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -447,7 +420,8 @@ class Levit(nn.Module):
|
|||
attn_act_layer='hard_swish',
|
||||
distillation=True,
|
||||
use_conv=False,
|
||||
drop_path=0):
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.):
|
||||
super().__init__()
|
||||
act_layer = get_act_layer(act_layer)
|
||||
attn_act_layer = get_act_layer(attn_act_layer)
|
||||
|
@ -486,7 +460,7 @@ class Levit(nn.Module):
|
|||
Attention(
|
||||
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
|
||||
resolution=resolution, use_conv=use_conv),
|
||||
drop_path))
|
||||
drop_path_rate))
|
||||
if mr > 0:
|
||||
h = int(ed * mr)
|
||||
self.blocks.append(
|
||||
|
@ -494,7 +468,7 @@ class Levit(nn.Module):
|
|||
ln_layer(ed, h, resolution=resolution),
|
||||
act_layer(),
|
||||
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
|
||||
), drop_path))
|
||||
), drop_path_rate))
|
||||
if do[0] == 'Subsample':
|
||||
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
resolution_ = (resolution - 1) // do[5] + 1
|
||||
|
@ -511,26 +485,45 @@ class Levit(nn.Module):
|
|||
ln_layer(embed_dim[i + 1], h, resolution=resolution),
|
||||
act_layer(),
|
||||
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
|
||||
), drop_path))
|
||||
), drop_path_rate))
|
||||
self.blocks = nn.Sequential(*self.blocks)
|
||||
|
||||
# Classifier head
|
||||
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = None
|
||||
if distillation:
|
||||
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
self.head_dist = None
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
||||
|
||||
def forward(self, x):
|
||||
def get_classifier(self):
|
||||
if self.head_dist is None:
|
||||
return self.head
|
||||
else:
|
||||
return self.head, self.head_dist
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='', distillation=None):
|
||||
self.num_classes = num_classes
|
||||
self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
if distillation is not None:
|
||||
self.distillation = distillation
|
||||
if self.distillation:
|
||||
self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
self.head_dist = None
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if not self.use_conv:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.blocks(x)
|
||||
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if self.head_dist is not None:
|
||||
x, x_dist = self.head(x), self.head_dist(x)
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
|
|
|
@ -14,8 +14,9 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2
|
|||
year={2021}
|
||||
}
|
||||
|
||||
Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
|
||||
Also supporting ResMlp, and a preliminary (not verified) implementations of gMLP
|
||||
|
||||
Code: https://github.com/facebookresearch/deit
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
@misc{touvron2021resmlp,
|
||||
title={ResMLP: Feedforward networks for image classification with data-efficient training},
|
||||
|
@ -45,7 +46,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
|
||||
from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
||||
from .registry import register_model
|
||||
|
||||
|
@ -92,13 +93,40 @@ default_cfgs = dict(
|
|||
),
|
||||
|
||||
gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
gmixer_24_224=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
|
||||
resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_12_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_24_224=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.89),
|
||||
resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
|
||||
#url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_36_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_big_24_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
|
||||
resmlp_12_distilled_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_24_distilled_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_36_distilled_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
resmlp_big_24_distilled_224=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
|
||||
resmlp_big_24_224_in22ft1k=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
|
||||
gmlp_ti16_224=_cfg(),
|
||||
gmlp_s16_224=_cfg(),
|
||||
|
@ -172,6 +200,11 @@ class SpatialGatingUnit(nn.Module):
|
|||
self.norm = norm_layer(gate_dim)
|
||||
self.proj = nn.Linear(seq_len, seq_len)
|
||||
|
||||
def init_weights(self):
|
||||
# special init for the projection gate, called as override by base model init
|
||||
nn.init.normal_(self.proj.weight, std=1e-6)
|
||||
nn.init.ones_(self.proj.bias)
|
||||
|
||||
def forward(self, x):
|
||||
u, v = x.chunk(2, dim=-1)
|
||||
v = self.norm(v)
|
||||
|
@ -208,7 +241,7 @@ class MlpMixer(nn.Module):
|
|||
in_chans=3,
|
||||
patch_size=16,
|
||||
num_blocks=8,
|
||||
hidden_dim=512,
|
||||
embed_dim=512,
|
||||
mlp_ratio=(0.5, 4.0),
|
||||
block_layer=MixerBlock,
|
||||
mlp_layer=Mlp,
|
||||
|
@ -221,59 +254,95 @@ class MlpMixer(nn.Module):
|
|||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.stem = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim,
|
||||
norm_layer=norm_layer if stem_norm else None)
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
|
||||
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
||||
self.blocks = nn.Sequential(*[
|
||||
block_layer(
|
||||
hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
|
||||
embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
|
||||
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
|
||||
for _ in range(num_blocks)])
|
||||
self.norm = norm_layer(hidden_dim)
|
||||
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Linear(embed_dim, self.num_classes) # zero init
|
||||
|
||||
self.init_weights(nlhb=nlhb)
|
||||
|
||||
def init_weights(self, nlhb=False):
|
||||
head_bias = -math.log(self.num_classes) if nlhb else 0.
|
||||
for n, m in self.named_modules():
|
||||
_init_weights(m, n, head_bias=head_bias)
|
||||
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
|
||||
|
||||
def forward(self, x):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = x.mean(dim=1)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _init_weights(m, n: str, head_bias: float = 0.):
|
||||
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
|
||||
""" Mixer weight initialization (trying to match Flax defaults)
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
if n.startswith('head'):
|
||||
nn.init.zeros_(m.weight)
|
||||
nn.init.constant_(m.bias, head_bias)
|
||||
elif n.endswith('gate.proj'):
|
||||
nn.init.normal_(m.weight, std=1e-4)
|
||||
nn.init.ones_(m.bias)
|
||||
if isinstance(module, nn.Linear):
|
||||
if name.startswith('head'):
|
||||
nn.init.zeros_(module.weight)
|
||||
nn.init.constant_(module.bias, head_bias)
|
||||
else:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
if 'mlp' in n:
|
||||
nn.init.normal_(m.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
lecun_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.ones_(m.weight)
|
||||
if flax:
|
||||
# Flax defaults
|
||||
lecun_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
else:
|
||||
# like MLP init in vit (my original init)
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
if 'mlp' in name:
|
||||
nn.init.normal_(module.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
lecun_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif hasattr(module, 'init_weights'):
|
||||
# NOTE if a parent module contains init_weights method, it can override the init of the
|
||||
# child modules as this will be called in depth-first order.
|
||||
module.init_weights()
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" Remap checkpoints if needed """
|
||||
if 'patch_embed.proj.weight' in state_dict:
|
||||
# Remap FB ResMlp models -> timm
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
k = k.replace('patch_embed.', 'stem.')
|
||||
k = k.replace('attn.', 'linear_tokens.')
|
||||
k = k.replace('mlp.', 'mlp_channels.')
|
||||
k = k.replace('gamma_', 'ls')
|
||||
if k.endswith('.alpha') or k.endswith('.beta'):
|
||||
v = v.reshape(1, 1, -1)
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
return state_dict
|
||||
|
||||
|
||||
def _create_mixer(variant, pretrained=False, **kwargs):
|
||||
|
@ -283,6 +352,7 @@ def _create_mixer(variant, pretrained=False, **kwargs):
|
|||
model = build_model_with_cfg(
|
||||
MlpMixer, variant, pretrained,
|
||||
default_cfg=default_cfgs[variant],
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
@ -292,7 +362,7 @@ def mixer_s32_224(pretrained=False, **kwargs):
|
|||
""" Mixer-S/32 224x224
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs)
|
||||
model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
|
||||
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -302,7 +372,7 @@ def mixer_s16_224(pretrained=False, **kwargs):
|
|||
""" Mixer-S/16 224x224
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
|
||||
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -312,7 +382,7 @@ def mixer_b32_224(pretrained=False, **kwargs):
|
|||
""" Mixer-B/32 224x224
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -322,7 +392,7 @@ def mixer_b16_224(pretrained=False, **kwargs):
|
|||
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -332,7 +402,7 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs):
|
|||
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -342,7 +412,7 @@ def mixer_l32_224(pretrained=False, **kwargs):
|
|||
""" Mixer-L/32 224x224.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs)
|
||||
model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
|
||||
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -352,7 +422,7 @@ def mixer_l16_224(pretrained=False, **kwargs):
|
|||
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
|
||||
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -362,35 +432,38 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs):
|
|||
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
|
||||
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
|
||||
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b16_224_miil(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
||||
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
|
||||
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
||||
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
||||
"""
|
||||
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
||||
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
|
||||
model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gmixer_12_224(pretrained=False, **kwargs):
|
||||
""" Glu-Mixer-12 224x224 (short & fat)
|
||||
""" Glu-Mixer-12 224x224
|
||||
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0),
|
||||
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
|
||||
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
||||
model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
@ -398,11 +471,11 @@ def gmixer_12_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def gmixer_24_224(pretrained=False, **kwargs):
|
||||
""" Glu-Mixer-24 224x224 (tall & slim)
|
||||
""" Glu-Mixer-24 224x224
|
||||
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0),
|
||||
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
|
||||
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
||||
model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
@ -414,7 +487,7 @@ def resmlp_12_224(pretrained=False, **kwargs):
|
|||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -425,7 +498,8 @@ def resmlp_24_224(pretrained=False, **kwargs):
|
|||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
@ -436,18 +510,90 @@ def resmlp_36_224(pretrained=False, **kwargs):
|
|||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resmlp_big_24_224(pretrained=False, **kwargs):
|
||||
""" ResMLP-B-24
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resmlp_12_distilled_224(pretrained=False, **kwargs):
|
||||
""" ResMLP-12
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resmlp_24_distilled_224(pretrained=False, **kwargs):
|
||||
""" ResMLP-24
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resmlp_36_distilled_224(pretrained=False, **kwargs):
|
||||
""" ResMLP-36
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
|
||||
""" ResMLP-B-24
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
|
||||
""" ResMLP-B-24
|
||||
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
|
||||
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
|
||||
model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gmlp_ti16_224(pretrained=False, **kwargs):
|
||||
""" gMLP-Tiny
|
||||
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
mlp_layer=GatedMlp, **kwargs)
|
||||
model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
@ -459,7 +605,7 @@ def gmlp_s16_224(pretrained=False, **kwargs):
|
|||
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
mlp_layer=GatedMlp, **kwargs)
|
||||
model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
@ -471,7 +617,7 @@ def gmlp_b16_224(pretrained=False, **kwargs):
|
|||
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
||||
"""
|
||||
model_args = dict(
|
||||
patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
||||
mlp_layer=GatedMlp, **kwargs)
|
||||
model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
|
|
@ -119,6 +119,7 @@ class MobileNetV3(nn.Module):
|
|||
num_pooled_chs = head_chs * self.global_pool.feat_mult()
|
||||
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
efficientnet_init_weights(self)
|
||||
|
@ -137,6 +138,7 @@ class MobileNetV3(nn.Module):
|
|||
self.num_classes = num_classes
|
||||
# cannot meaningfully change pooling of efficient head after creation
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
|
@ -151,8 +153,7 @@ class MobileNetV3(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
if not self.global_pool.is_identity():
|
||||
x = x.flatten(1)
|
||||
x = self.flatten(x)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
|
|
@ -111,11 +111,11 @@ default_cfgs = dict(
|
|||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
|
||||
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
|
||||
eca_nfnet_l2=_dcfg(
|
||||
url='',
|
||||
pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth',
|
||||
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
|
||||
eca_nfnet_l3=_dcfg(
|
||||
url='',
|
||||
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
|
||||
pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0),
|
||||
|
||||
nf_regnet_b0=_dcfg(
|
||||
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
|
||||
|
@ -166,6 +166,7 @@ class NfCfg:
|
|||
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
|
||||
gamma_in_act: bool = False
|
||||
same_padding: bool = False
|
||||
std_conv_eps: float = 1e-5
|
||||
skipinit: bool = False # disabled by default, non-trivial performance impact
|
||||
zero_init_fc: bool = False
|
||||
act_layer: str = 'silu'
|
||||
|
@ -209,6 +210,7 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
|
|||
return cfg
|
||||
|
||||
|
||||
|
||||
model_cfgs = dict(
|
||||
# NFNet-F models w/ GELU compatible with DeepMind weights
|
||||
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
|
||||
|
@ -482,10 +484,10 @@ class NormFreeNet(nn.Module):
|
|||
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
|
||||
if cfg.gamma_in_act:
|
||||
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps
|
||||
conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
|
||||
else:
|
||||
act_layer = get_act_layer(cfg.act_layer)
|
||||
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
|
||||
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
|
||||
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
|
||||
|
||||
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
|
||||
|
|
|
@ -186,12 +186,13 @@ class PoolingVisionTransformer(nn.Module):
|
|||
]
|
||||
self.transformers = SequentialTuple(*transformers)
|
||||
self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
|
||||
self.embed_dim = base_dims[-1] * heads[-1]
|
||||
self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
|
||||
if num_classes > 0 and distilled else nn.Identity()
|
||||
self.head_dist = None
|
||||
if distilled:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
@ -207,13 +208,16 @@ class PoolingVisionTransformer(nn.Module):
|
|||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
if self.head_dist is not None:
|
||||
return self.head, self.head_dist
|
||||
else:
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
|
||||
if num_classes > 0 and self.num_tokens == 2 else nn.Identity()
|
||||
if self.head_dist is not None:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
|
@ -221,19 +225,21 @@ class PoolingVisionTransformer(nn.Module):
|
|||
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x, cls_tokens = self.transformers((x, cls_tokens))
|
||||
cls_tokens = self.norm(cls_tokens)
|
||||
return cls_tokens
|
||||
if self.head_dist is not None:
|
||||
return cls_tokens[:, 0], cls_tokens[:, 1]
|
||||
else:
|
||||
return cls_tokens[:, 0]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x_cls = self.head(x[:, 0])
|
||||
if self.num_tokens > 1:
|
||||
x_dist = self.head_dist(x[:, 1])
|
||||
if self.head_dist is not None:
|
||||
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
|
||||
if self.training and not torch.jit.is_scripting():
|
||||
return x_cls, x_dist
|
||||
return x, x_dist
|
||||
else:
|
||||
return (x_cls + x_dist) / 2
|
||||
return (x + x_dist) / 2
|
||||
else:
|
||||
return x_cls
|
||||
return self.head(x)
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
|
|
|
@ -65,11 +65,18 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
|
|||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
"""
|
||||
if module:
|
||||
models = list(_module_to_models[module])
|
||||
all_models = list(_module_to_models[module])
|
||||
else:
|
||||
models = _model_entrypoints.keys()
|
||||
all_models = _model_entrypoints.keys()
|
||||
if filter:
|
||||
models = fnmatch.filter(models, filter) # include these models
|
||||
models = []
|
||||
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
||||
for f in include_filters:
|
||||
include_models = fnmatch.filter(all_models, f) # include these models
|
||||
if len(include_models):
|
||||
models = set(models).union(include_models)
|
||||
else:
|
||||
models = all_models
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
exclude_filters = [exclude_filters]
|
||||
|
|
|
@ -638,12 +638,15 @@ class ResNet(nn.Module):
|
|||
self.num_features = 512 * block.expansion
|
||||
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
self.init_weights(zero_init_last_bn=zero_init_last_bn)
|
||||
|
||||
def init_weights(self, zero_init_last_bn=True):
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1.)
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
if zero_init_last_bn:
|
||||
for m in self.modules():
|
||||
if hasattr(m, 'zero_init_last_bn'):
|
||||
|
|
|
@ -11,6 +11,7 @@ https://github.com/google-research/vision_transformer
|
|||
Thanks to the Google team for the above two repositories and associated papers:
|
||||
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
|
||||
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
|
||||
* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
|
||||
|
||||
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
|
||||
"""
|
||||
|
@ -35,16 +36,16 @@ import torch.nn as nn
|
|||
from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
|
||||
from .registry import register_model
|
||||
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d
|
||||
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
|
||||
'crop_pct': 1.0, 'interpolation': 'bilinear',
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
|
@ -54,17 +55,23 @@ def _cfg(url='', **kwargs):
|
|||
default_cfgs = {
|
||||
# pretrained on imagenet21k, finetuned on imagenet1k
|
||||
'resnetv2_50x1_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
|
||||
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
|
||||
'resnetv2_50x3_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
|
||||
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
|
||||
'resnetv2_101x1_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
|
||||
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
|
||||
'resnetv2_101x3_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
|
||||
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
|
||||
'resnetv2_152x2_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
|
||||
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
|
||||
'resnetv2_152x4_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
|
||||
input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480?
|
||||
|
||||
# trained on imagenet-21k
|
||||
'resnetv2_50x1_bitm_in21k': _cfg(
|
||||
|
@ -86,20 +93,20 @@ default_cfgs = {
|
|||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
|
||||
num_classes=21843),
|
||||
|
||||
'resnetv2_50x1_bit_distilled': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
|
||||
interpolation='bicubic'),
|
||||
'resnetv2_152x2_bit_teacher': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
|
||||
interpolation='bicubic'),
|
||||
'resnetv2_152x2_bit_teacher_384': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
|
||||
|
||||
# trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
|
||||
# 'resnetv2_50x1_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
|
||||
# 'resnetv2_50x3_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
|
||||
# 'resnetv2_101x1_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
|
||||
# 'resnetv2_101x3_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
|
||||
# 'resnetv2_152x2_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
|
||||
# 'resnetv2_152x4_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
|
||||
'resnetv2_50': _cfg(
|
||||
interpolation='bicubic'),
|
||||
'resnetv2_50d': _cfg(
|
||||
interpolation='bicubic', first_conv='stem.conv1'),
|
||||
}
|
||||
|
||||
|
||||
|
@ -111,13 +118,6 @@ def make_div(v, divisor=8):
|
|||
return new_v
|
||||
|
||||
|
||||
def tf2th(conv_weights):
|
||||
"""Possibly convert HWIO to OIHW."""
|
||||
if conv_weights.ndim == 4:
|
||||
conv_weights = conv_weights.transpose([3, 2, 0, 1])
|
||||
return torch.from_numpy(conv_weights)
|
||||
|
||||
|
||||
class PreActBottleneck(nn.Module):
|
||||
"""Pre-activation (v2) bottleneck block.
|
||||
|
||||
|
@ -152,6 +152,9 @@ class PreActBottleneck(nn.Module):
|
|||
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
|
||||
def zero_init_last_bn(self):
|
||||
nn.init.zeros_(self.norm3.weight)
|
||||
|
||||
def forward(self, x):
|
||||
x_preact = self.norm1(x)
|
||||
|
||||
|
@ -198,6 +201,9 @@ class Bottleneck(nn.Module):
|
|||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.act3 = act_layer(inplace=True)
|
||||
|
||||
def zero_init_last_bn(self):
|
||||
nn.init.zeros_(self.norm3.weight)
|
||||
|
||||
def forward(self, x):
|
||||
# shortcut branch
|
||||
shortcut = x
|
||||
|
@ -285,14 +291,17 @@ def create_resnetv2_stem(
|
|||
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
||||
mid_chs = out_chs // 2
|
||||
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
||||
stem['norm1'] = norm_layer(mid_chs)
|
||||
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
||||
stem['norm2'] = norm_layer(mid_chs)
|
||||
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
|
||||
if not preact:
|
||||
stem['norm3'] = norm_layer(out_chs)
|
||||
else:
|
||||
# The usual 7x7 stem conv
|
||||
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
|
||||
|
||||
if not preact:
|
||||
stem['norm'] = norm_layer(out_chs)
|
||||
if not preact:
|
||||
stem['norm'] = norm_layer(out_chs)
|
||||
|
||||
if 'fixed' in stem_type:
|
||||
# 'fixed' SAME padding approximation that is used in BiT models
|
||||
|
@ -312,11 +321,12 @@ class ResNetV2(nn.Module):
|
|||
"""Implementation of Pre-activation (v2) ResNet mode.
|
||||
"""
|
||||
|
||||
def __init__(self, layers, channels=(256, 512, 1024, 2048),
|
||||
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
||||
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
||||
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
||||
drop_rate=0., drop_path_rate=0.):
|
||||
def __init__(
|
||||
self, layers, channels=(256, 512, 1024, 2048),
|
||||
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
||||
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
||||
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
||||
drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
|
@ -354,12 +364,14 @@ class ResNetV2(nn.Module):
|
|||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
|
||||
nn.init.normal_(m.weight, mean=0.0, std=0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
self.init_weights(zero_init_last_bn=zero_init_last_bn)
|
||||
|
||||
def init_weights(self, zero_init_last_bn=True):
|
||||
named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
|
||||
|
||||
@torch.jit.ignore()
|
||||
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
|
||||
_load_weights(self, checkpoint_path, prefix)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
@ -378,41 +390,59 @@ class ResNetV2(nn.Module):
|
|||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
if not self.head.global_pool.is_identity():
|
||||
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
|
||||
return x
|
||||
|
||||
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
|
||||
import numpy as np
|
||||
weights = np.load(checkpoint_path)
|
||||
with torch.no_grad():
|
||||
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
|
||||
if self.stem.conv.weight.shape[1] == 1:
|
||||
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
|
||||
# FIXME handle > 3 in_chans?
|
||||
else:
|
||||
self.stem.conv.weight.copy_(stem_conv_w)
|
||||
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
|
||||
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
|
||||
if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
|
||||
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
||||
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
|
||||
for i, (sname, stage) in enumerate(self.stages.named_children()):
|
||||
for j, (bname, block) in enumerate(stage.blocks.named_children()):
|
||||
convname = 'standardized_conv2d'
|
||||
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
|
||||
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
|
||||
block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
|
||||
block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
|
||||
block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
|
||||
block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
|
||||
block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
|
||||
block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
|
||||
block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
|
||||
block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
|
||||
if block.downsample is not None:
|
||||
w = weights[f'{block_prefix}a/proj/{convname}/kernel']
|
||||
block.downsample.conv.weight.copy_(tf2th(w))
|
||||
|
||||
def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
|
||||
if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.01)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
|
||||
module.zero_init_last_bn()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
|
||||
import numpy as np
|
||||
|
||||
def t2p(conv_weights):
|
||||
"""Possibly convert HWIO to OIHW."""
|
||||
if conv_weights.ndim == 4:
|
||||
conv_weights = conv_weights.transpose([3, 2, 0, 1])
|
||||
return torch.from_numpy(conv_weights)
|
||||
|
||||
weights = np.load(checkpoint_path)
|
||||
stem_conv_w = adapt_input_conv(
|
||||
model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
|
||||
model.stem.conv.weight.copy_(stem_conv_w)
|
||||
model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
|
||||
model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
|
||||
if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
|
||||
model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
|
||||
model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
|
||||
for i, (sname, stage) in enumerate(model.stages.named_children()):
|
||||
for j, (bname, block) in enumerate(stage.blocks.named_children()):
|
||||
cname = 'standardized_conv2d'
|
||||
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
|
||||
block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
|
||||
block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
|
||||
block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
|
||||
block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
|
||||
block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
|
||||
block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
|
||||
block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
|
||||
block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
|
||||
block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
|
||||
if block.downsample is not None:
|
||||
w = weights[f'{block_prefix}a/proj/{cname}/kernel']
|
||||
block.downsample.conv.weight.copy_(t2p(w))
|
||||
|
||||
|
||||
def _create_resnetv2(variant, pretrained=False, **kwargs):
|
||||
|
@ -425,130 +455,126 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
|
|||
**kwargs)
|
||||
|
||||
|
||||
def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x1_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x3_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x1_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x3_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x2_bitm', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x4_bitm', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
layers=[3, 4, 6, 3], width_factor=1, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
layers=[3, 4, 6, 3], width_factor=3, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
layers=[3, 4, 23, 3], width_factor=1, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
layers=[3, 4, 23, 3], width_factor=3, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
layers=[3, 8, 36, 3], width_factor=2, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
layers=[3, 8, 36, 3], width_factor=4, **kwargs)
|
||||
|
||||
|
||||
# NOTE the 'S' versions of the model weights arent as interesting as original 21k or transfer to 1K M.
|
||||
@register_model
|
||||
def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
|
||||
""" ResNetV2-50x1-BiT Distilled
|
||||
Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
|
||||
"""
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
|
||||
|
||||
# @register_model
|
||||
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50x1_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50x3_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_101x1_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_101x3_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_152x2_bits', pretrained=pretrained,
|
||||
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_152x4_bits', pretrained=pretrained,
|
||||
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
#
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
|
||||
""" ResNetV2-152x2-BiT Teacher
|
||||
Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
|
||||
"""
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
|
||||
""" ResNetV2-152xx-BiT Teacher @ 384x384
|
||||
Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
|
||||
"""
|
||||
return _create_resnetv2_bit(
|
||||
'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50d(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50d', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
|
||||
stem_type='deep', avg_down=True, **kwargs)
|
||||
|
|
|
@ -126,19 +126,18 @@ class WindowAttention(nn.Module):
|
|||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
|
@ -210,7 +209,6 @@ class SwinTransformerBlock(nn.Module):
|
|||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
|
@ -219,7 +217,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
@ -236,8 +234,8 @@ class SwinTransformerBlock(nn.Module):
|
|||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop, proj_drop=drop)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
|
@ -369,7 +367,6 @@ class BasicLayer(nn.Module):
|
|||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
|
@ -379,7 +376,7 @@ class BasicLayer(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
||||
|
||||
super().__init__()
|
||||
|
@ -390,14 +387,11 @@ class BasicLayer(nn.Module):
|
|||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
||||
num_heads=num_heads, window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop, attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer)
|
||||
SwinTransformerBlock(
|
||||
dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
|
@ -436,7 +430,6 @@ class SwinTransformer(nn.Module):
|
|||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
|
@ -448,7 +441,7 @@ class SwinTransformer(nn.Module):
|
|||
|
||||
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
|
||||
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
use_checkpoint=False, weight_init='', **kwargs):
|
||||
|
@ -491,8 +484,9 @@ class SwinTransformer(nn.Module):
|
|||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
|
@ -520,6 +514,13 @@ class SwinTransformer(nn.Module):
|
|||
def no_weight_decay_keywords(self):
|
||||
return {'relative_position_bias_table'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.absolute_pos_embed is not None:
|
||||
|
|
|
@ -278,6 +278,8 @@ class Twins(nn.Module):
|
|||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.depths = depths
|
||||
self.embed_dims = embed_dims
|
||||
self.num_features = embed_dims[-1]
|
||||
|
||||
img_size = to_2tuple(img_size)
|
||||
prev_chs = in_chans
|
||||
|
@ -303,10 +305,10 @@ class Twins(nn.Module):
|
|||
|
||||
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
|
||||
|
||||
self.norm = norm_layer(embed_dims[-1])
|
||||
self.norm = norm_layer(self.num_features)
|
||||
|
||||
# classification head
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
# init weights
|
||||
self.apply(self._init_weights)
|
||||
|
@ -320,7 +322,7 @@ class Twins(nn.Module):
|
|||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn.functional as F
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
|
||||
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
|
@ -140,14 +140,14 @@ class Visformer(nn.Module):
|
|||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
|
||||
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
|
||||
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
|
||||
vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.init_channels = init_channels
|
||||
self.img_size = img_size
|
||||
self.vit_stem = vit_stem
|
||||
self.pool = pool
|
||||
self.conv_init = conv_init
|
||||
if isinstance(depth, (list, tuple)):
|
||||
self.stage_num1, self.stage_num2, self.stage_num3 = depth
|
||||
|
@ -164,31 +164,31 @@ class Visformer(nn.Module):
|
|||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 16
|
||||
img_size = [x // 16 for x in img_size]
|
||||
else:
|
||||
if self.init_channels is None:
|
||||
self.stem = None
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
|
||||
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 8
|
||||
img_size = [x // 8 for x in img_size]
|
||||
else:
|
||||
self.stem = nn.Sequential(
|
||||
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
|
||||
nn.BatchNorm2d(self.init_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
img_size //= 2
|
||||
img_size = [x // 2 for x in img_size]
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
|
||||
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 4
|
||||
img_size = [x // 4 for x in img_size]
|
||||
|
||||
if self.pos_embed:
|
||||
if self.vit_stem:
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
|
||||
else:
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.stage1 = nn.ModuleList([
|
||||
Block(
|
||||
|
@ -199,14 +199,14 @@ class Visformer(nn.Module):
|
|||
for i in range(self.stage_num1)
|
||||
])
|
||||
|
||||
#stage2
|
||||
# stage2
|
||||
if not self.vit_stem:
|
||||
self.patch_embed2 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
|
||||
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 2
|
||||
img_size = [x // 2 for x in img_size]
|
||||
if self.pos_embed:
|
||||
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
|
||||
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
|
||||
self.stage2 = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
|
||||
|
@ -221,9 +221,9 @@ class Visformer(nn.Module):
|
|||
self.patch_embed3 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
|
||||
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size //= 2
|
||||
img_size = [x // 2 for x in img_size]
|
||||
if self.pos_embed:
|
||||
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
|
||||
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
|
||||
self.stage3 = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
|
||||
|
@ -234,11 +234,10 @@ class Visformer(nn.Module):
|
|||
])
|
||||
|
||||
# head
|
||||
if self.pool:
|
||||
self.global_pooling = nn.AdaptiveAvgPool2d(1)
|
||||
head_dim = embed_dim if self.vit_stem else embed_dim * 2
|
||||
self.norm = norm_layer(head_dim)
|
||||
self.head = nn.Linear(head_dim, num_classes)
|
||||
self.num_features = embed_dim if self.vit_stem else embed_dim * 2
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
self.head = nn.Linear(self.num_features, num_classes)
|
||||
|
||||
# weights init
|
||||
if self.pos_embed:
|
||||
|
@ -267,7 +266,14 @@ class Visformer(nn.Module):
|
|||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
|
||||
def forward(self, x):
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.num_classes = num_classes
|
||||
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
if self.stem is not None:
|
||||
x = self.stem(x)
|
||||
|
||||
|
@ -297,14 +303,13 @@ class Visformer(nn.Module):
|
|||
for b in self.stage3:
|
||||
x = b(x)
|
||||
|
||||
# head
|
||||
x = self.norm(x)
|
||||
if self.pool:
|
||||
x = self.global_pooling(x)
|
||||
else:
|
||||
x = x[:, :, 0, 0]
|
||||
return x
|
||||
|
||||
x = self.head(x.view(x.size(0), -1))
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.global_pool(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -321,7 +326,7 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
|
|||
@register_model
|
||||
def visformer_tiny(pretrained=False, **kwargs):
|
||||
model_cfg = dict(
|
||||
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
|
||||
init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
|
||||
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
|
||||
embed_norm=nn.BatchNorm2d, **kwargs)
|
||||
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
|
||||
|
@ -331,7 +336,7 @@ def visformer_tiny(pretrained=False, **kwargs):
|
|||
@register_model
|
||||
def visformer_small(pretrained=False, **kwargs):
|
||||
model_cfg = dict(
|
||||
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
|
||||
init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
|
||||
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
|
||||
embed_norm=nn.BatchNorm2d, **kwargs)
|
||||
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
""" Vision Transformer (ViT) in PyTorch
|
||||
|
||||
A PyTorch implement of Vision Transformers as described in
|
||||
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
|
||||
A PyTorch implement of Vision Transformers as described in:
|
||||
|
||||
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
|
||||
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
|
||||
- https://arxiv.org/abs/2106.TODO
|
||||
|
||||
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
||||
|
||||
|
@ -15,7 +20,7 @@ for some einops/einsum fun
|
|||
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
||||
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import logging
|
||||
|
@ -27,8 +32,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
|
||||
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
|
||||
from .registry import register_model
|
||||
|
||||
|
@ -40,86 +45,118 @@ def _cfg(url='', **kwargs):
|
|||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# patch models (my experiments)
|
||||
# patch models (weights from official Google JAX impl)
|
||||
'vit_tiny_patch16_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
||||
'vit_tiny_patch16_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_small_patch32_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
||||
'vit_small_patch32_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_small_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
||||
),
|
||||
|
||||
# patch models (weights ported from official Google JAX impl)
|
||||
'vit_base_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
),
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
||||
'vit_small_patch16_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_base_patch32_224': _cfg(
|
||||
url='', # no official model weights for this combo, only for in21k
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_base_patch16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
|
||||
'vit_base_patch32_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_large_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_base_patch16_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
|
||||
'vit_base_patch16_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_large_patch32_224': _cfg(
|
||||
url='', # no official model weights for this combo, only for in21k
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_large_patch16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
),
|
||||
'vit_large_patch32_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_large_patch16_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
|
||||
'vit_large_patch16_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
# patch models, imagenet21k (weights ported from official Google JAX impl)
|
||||
'vit_base_patch16_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
# patch models, imagenet21k (weights from official Google JAX impl)
|
||||
'vit_tiny_patch16_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843),
|
||||
'vit_small_patch32_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843),
|
||||
'vit_small_patch16_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843),
|
||||
'vit_base_patch32_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_large_patch16_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843),
|
||||
'vit_base_patch16_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843),
|
||||
'vit_large_patch32_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
num_classes=21843),
|
||||
'vit_large_patch16_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
|
||||
num_classes=21843),
|
||||
'vit_huge_patch14_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
|
||||
hf_hub='timm/vit_huge_patch14_224_in21k',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
num_classes=21843),
|
||||
|
||||
# deit models (FB weights)
|
||||
'vit_deit_tiny_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
||||
'vit_deit_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
||||
'vit_deit_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
|
||||
'vit_deit_base_patch16_384': _cfg(
|
||||
'deit_tiny_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'deit_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'deit_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'deit_base_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_deit_tiny_distilled_patch16_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit_tiny_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'vit_deit_small_distilled_patch16_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
||||
'deit_small_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'vit_deit_base_distilled_patch16_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
|
||||
classifier=('head', 'head_dist')),
|
||||
'vit_deit_base_distilled_patch16_384': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
|
||||
'deit_base_distilled_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
|
||||
classifier=('head', 'head_dist')),
|
||||
|
||||
# ViT ImageNet-21K-P pretraining
|
||||
# ViT ImageNet-21K-P pretraining by MILL
|
||||
'vit_base_patch16_224_miil_in21k': _cfg(
|
||||
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
|
||||
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
|
||||
|
@ -133,11 +170,11 @@ default_cfgs = {
|
|||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
|
@ -161,12 +198,11 @@ class Attention(nn.Module):
|
|||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
|
@ -190,7 +226,7 @@ class VisionTransformer(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
|
||||
act_layer=None, weight_init=''):
|
||||
"""
|
||||
|
@ -204,7 +240,6 @@ class VisionTransformer(nn.Module):
|
|||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
||||
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
||||
distilled (bool): model includes a distillation token and head as in DeiT models
|
||||
drop_rate (float): dropout rate
|
||||
|
@ -233,8 +268,8 @@ class VisionTransformer(nn.Module):
|
|||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.Sequential(*[
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
|
||||
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
|
@ -254,16 +289,17 @@ class VisionTransformer(nn.Module):
|
|||
if distilled:
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
# Weight init
|
||||
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
|
||||
head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
|
||||
self.init_weights(weight_init)
|
||||
|
||||
def init_weights(self, mode=''):
|
||||
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
|
||||
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
if self.dist_token is not None:
|
||||
trunc_normal_(self.dist_token, std=.02)
|
||||
if weight_init.startswith('jax'):
|
||||
if mode.startswith('jax'):
|
||||
# leave cls token as zeros to match jax impl
|
||||
for n, m in self.named_modules():
|
||||
_init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
|
||||
named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
|
||||
else:
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
self.apply(_init_vit_weights)
|
||||
|
@ -272,6 +308,10 @@ class VisionTransformer(nn.Module):
|
|||
# this fn left here for compat with downstream users
|
||||
_init_vit_weights(m)
|
||||
|
||||
@torch.jit.ignore()
|
||||
def load_pretrained(self, checkpoint_path, prefix=''):
|
||||
_load_weights(self, checkpoint_path, prefix)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token', 'dist_token'}
|
||||
|
@ -317,39 +357,116 @@ class VisionTransformer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
|
||||
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
|
||||
""" ViT weight initialization
|
||||
* When called without n, head_bias, jax_impl args it will behave exactly the same
|
||||
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
|
||||
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
|
||||
"""
|
||||
if isinstance(m, nn.Linear):
|
||||
if n.startswith('head'):
|
||||
nn.init.zeros_(m.weight)
|
||||
nn.init.constant_(m.bias, head_bias)
|
||||
elif n.startswith('pre_logits'):
|
||||
lecun_normal_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
if isinstance(module, nn.Linear):
|
||||
if name.startswith('head'):
|
||||
nn.init.zeros_(module.weight)
|
||||
nn.init.constant_(module.bias, head_bias)
|
||||
elif name.startswith('pre_logits'):
|
||||
lecun_normal_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
else:
|
||||
if jax_impl:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if m.bias is not None:
|
||||
if 'mlp' in n:
|
||||
nn.init.normal_(m.bias, std=1e-6)
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
if 'mlp' in name:
|
||||
nn.init.normal_(module.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.zeros_(module.bias)
|
||||
else:
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif jax_impl and isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(module.weight, std=.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif jax_impl and isinstance(module, nn.Conv2d):
|
||||
# NOTE conv was left to pytorch default in my original init
|
||||
lecun_normal_(m.weight)
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.ones_(m.weight)
|
||||
lecun_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
||||
nn.init.zeros_(module.bias)
|
||||
nn.init.ones_(module.weight)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
||||
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
def _n2p(w, t=True):
|
||||
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
||||
w = w.flatten()
|
||||
if t:
|
||||
if w.ndim == 4:
|
||||
w = w.transpose([3, 2, 0, 1])
|
||||
elif w.ndim == 3:
|
||||
w = w.transpose([2, 0, 1])
|
||||
elif w.ndim == 2:
|
||||
w = w.transpose([1, 0])
|
||||
return torch.from_numpy(w)
|
||||
|
||||
w = np.load(checkpoint_path)
|
||||
if not prefix and 'opt/target/embedding/kernel' in w:
|
||||
prefix = 'opt/target/'
|
||||
|
||||
if hasattr(model.patch_embed, 'backbone'):
|
||||
# hybrid
|
||||
backbone = model.patch_embed.backbone
|
||||
stem_only = not hasattr(backbone, 'stem')
|
||||
stem = backbone if stem_only else backbone.stem
|
||||
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
||||
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
||||
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
||||
if not stem_only:
|
||||
for i, stage in enumerate(backbone.stages):
|
||||
for j, block in enumerate(stage.blocks):
|
||||
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
||||
for r in range(3):
|
||||
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
||||
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
||||
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
||||
if block.downsample is not None:
|
||||
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
||||
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
||||
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
||||
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
||||
else:
|
||||
embed_conv_w = adapt_input_conv(
|
||||
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
||||
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
||||
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
||||
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
||||
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
||||
if pos_embed_w.shape != model.pos_embed.shape:
|
||||
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
||||
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||
model.pos_embed.copy_(pos_embed_w)
|
||||
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||
if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
||||
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
||||
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
||||
for i, block in enumerate(model.blocks.children()):
|
||||
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
||||
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
||||
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
||||
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
||||
block.attn.qkv.weight.copy_(torch.cat([
|
||||
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
||||
block.attn.qkv.bias.copy_(torch.cat([
|
||||
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
||||
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
||||
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
||||
for r in range(2):
|
||||
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
||||
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
||||
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
||||
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
||||
|
||||
|
||||
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
||||
|
@ -413,34 +530,64 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
|
|||
default_cfg=default_cfg,
|
||||
representation_size=repr_size,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
pretrained_custom_load='npz' in default_cfg['url'],
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
|
||||
NOTE:
|
||||
* this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
|
||||
* this model does not have a bias for QKV (unlike the official ViT and DeiT models)
|
||||
def vit_tiny_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Tiny (Vit-Ti/16)
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
|
||||
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
|
||||
if pretrained:
|
||||
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
||||
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_tiny_patch16_384(pretrained=False, **kwargs):
|
||||
""" ViT-Tiny (Vit-Ti/16) @ 384x384.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch32_224(pretrained=False, **kwargs):
|
||||
""" ViT-Small (ViT-S/32)
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch32_384(pretrained=False, **kwargs):
|
||||
""" ViT-Small (ViT-S/32) at 384x384.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Small (ViT-S/16)
|
||||
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
def vit_small_patch16_384(pretrained=False, **kwargs):
|
||||
""" ViT-Small (ViT-S/16)
|
||||
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -453,16 +600,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_384(pretrained=False, **kwargs):
|
||||
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_384(pretrained=False, **kwargs):
|
||||
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
|
@ -474,12 +611,22 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
def vit_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_384(pretrained=False, **kwargs):
|
||||
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -492,16 +639,6 @@ def vit_large_patch32_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_384(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch32_384(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
|
@ -513,13 +650,52 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
def vit_large_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_384(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Tiny (Vit-Ti/16).
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Small (ViT-S/16)
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Small (ViT-S/16)
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -535,13 +711,13 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -556,6 +732,17 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
|
||||
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
|
@ -569,86 +756,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
|
||||
def deit_tiny_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
|
||||
def deit_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
|
||||
def deit_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
|
||||
def deit_base_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer(
|
||||
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer(
|
||||
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer(
|
||||
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
||||
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer(
|
||||
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
""" Hybrid Vision Transformer (ViT) in PyTorch
|
||||
|
||||
A PyTorch implement of the Hybrid Vision Transformers as described in
|
||||
A PyTorch implement of the Hybrid Vision Transformers as described in:
|
||||
|
||||
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
|
||||
NOTE This relies on code in vision_transformer.py. The hybrid model definitions were moved here to
|
||||
keep file sizes sane.
|
||||
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
|
||||
- https://arxiv.org/abs/2106.TODO
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
NOTE These hybrid model definitions depend on code in vision_transformer.py.
|
||||
They were moved here to keep file sizes sane.
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
@ -35,32 +39,61 @@ def _cfg(url='', **kwargs):
|
|||
|
||||
|
||||
default_cfgs = {
|
||||
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
|
||||
'vit_base_r50_s16_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
|
||||
num_classes=21843, crop_pct=0.9),
|
||||
|
||||
# hybrid in-1k models (weights ported from official JAX impl)
|
||||
# hybrid in-1k models (weights from official JAX impl where they exist)
|
||||
'vit_tiny_r_s16_p8_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
|
||||
first_conv='patch_embed.backbone.conv'),
|
||||
'vit_tiny_r_s16_p8_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
||||
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_small_r26_s32_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
|
||||
),
|
||||
'vit_small_r26_s32_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_base_r26_s32_224': _cfg(),
|
||||
'vit_base_r50_s16_224': _cfg(),
|
||||
'vit_base_r50_s16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_large_r50_s32_224': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
|
||||
),
|
||||
'vit_large_r50_s32_384': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/'
|
||||
'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
|
||||
input_size=(3, 384, 384), crop_pct=1.0
|
||||
),
|
||||
|
||||
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
|
||||
'vit_tiny_r_s16_p8_224': _cfg(),
|
||||
'vit_small_r_s16_p8_224': _cfg(),
|
||||
'vit_small_r20_s16_p2_224': _cfg(),
|
||||
'vit_small_r20_s16_224': _cfg(),
|
||||
'vit_small_r26_s32_224': _cfg(),
|
||||
'vit_base_r20_s16_224': _cfg(),
|
||||
'vit_base_r26_s32_224': _cfg(),
|
||||
'vit_base_r50_s16_224': _cfg(),
|
||||
'vit_large_r50_s32_224': _cfg(),
|
||||
# hybrid in-21k models (weights from official Google JAX impl where they exist)
|
||||
'vit_tiny_r_s16_p8_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
|
||||
'vit_small_r26_s32_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843, crop_pct=0.9),
|
||||
'vit_base_r50_s16_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
|
||||
num_classes=21843, crop_pct=0.9),
|
||||
'vit_large_r50_s32_224_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
|
||||
num_classes=21843, crop_pct=0.9),
|
||||
|
||||
# hybrid models (using timm resnet backbones)
|
||||
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'vit_small_resnet26d_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
|
||||
'vit_small_resnet50d_s16_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
|
||||
'vit_base_resnet26d_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
|
||||
'vit_base_resnet50d_224': _cfg(
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
|
||||
}
|
||||
|
||||
|
||||
|
@ -95,7 +128,8 @@ class HybridEmbed(nn.Module):
|
|||
else:
|
||||
feature_dim = self.backbone.num_features
|
||||
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
||||
self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
|
||||
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -116,12 +150,8 @@ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwa
|
|||
def _resnetv2(layers=(3, 4, 9), **kwargs):
|
||||
""" ResNet-V2 backbone helper"""
|
||||
padding_same = kwargs.get('padding_same', True)
|
||||
if padding_same:
|
||||
stem_type = 'same'
|
||||
conv_layer = StdConv2dSame
|
||||
else:
|
||||
stem_type = ''
|
||||
conv_layer = StdConv2d
|
||||
stem_type = 'same' if padding_same else ''
|
||||
conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
|
||||
if len(layers):
|
||||
backbone = ResNetV2(
|
||||
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
|
||||
|
@ -132,42 +162,6 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
|
|||
return backbone
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
|
||||
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
|
||||
# NOTE this is forwarding to model def above for backwards compatibility
|
||||
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r50_s16_384(pretrained=False, **kwargs):
|
||||
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
backbone = _resnetv2((3, 4, 9), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_resnet50_384(pretrained=False, **kwargs):
|
||||
# NOTE this is forwarding to model def above for backwards compatibility
|
||||
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
||||
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||
|
@ -180,36 +174,13 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
|
||||
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
|
||||
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
|
||||
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
|
||||
"""
|
||||
backbone = _resnetv2((2, 4), **kwargs)
|
||||
model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r20_s16_224(pretrained=False, **kwargs):
|
||||
""" R20+ViT-S/S16 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -225,13 +196,13 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def vit_base_r20_s16_224(pretrained=False, **kwargs):
|
||||
""" R20+ViT-B/S16 hybrid.
|
||||
def vit_small_r26_s32_384(pretrained=False, **kwargs):
|
||||
""" R26+ViT-S/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -257,17 +228,97 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r50_s16_384(pretrained=False, **kwargs):
|
||||
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
backbone = _resnetv2((3, 4, 9), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_resnet50_384(pretrained=False, **kwargs):
|
||||
# DEPRECATED this is forwarding to model def above for backwards compatibility
|
||||
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_r50_s32_224(pretrained=False, **kwargs):
|
||||
""" R50+ViT-L/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_r50_s32_384(pretrained=False, **kwargs):
|
||||
""" R50+ViT-L/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
||||
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
|
||||
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
|
||||
""" R26+ViT-S/S32 hybrid. ImageNet-21k.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
|
||||
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
|
||||
# DEPRECATED this is forwarding to model def above for backwards compatibility
|
||||
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
|
||||
""" R50+ViT-L/S32 hybrid. ImageNet-21k.
|
||||
"""
|
||||
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
||||
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer_hybrid(
|
||||
'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_resnet26d_224(pretrained=False, **kwargs):
|
||||
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '0.4.11'
|
||||
__version__ = '0.4.12'
|
||||
|
|
Loading…
Reference in New Issue