Merge pull request #702 from rwightman/cleanup_xla_model_fixes

AugReg Vision Transformers, XLA model compat for ResNetV2-BiT / NFNet, ECA-NFNet-L2, GMixer-24 weights, ResMLP official weights, and cleanup
Ross Wightman 2021-06-20 17:49:14 -07:00 committed by GitHub
commit 79927baaec
28 changed files with 1234 additions and 700 deletions


@ -23,6 +23,25 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### June 20, 2021
* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
* .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg) (see the loading sketch after this list)
* See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from official impl for navigating the augreg weights
* Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
* Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
* `vit_deit_*` renamed to just `deit_*`
* Remove my old small model, replace with DeiT compatible small w/ AugReg weights
* Add 1st training of my `gmixer_24_224` MLP w/ GLU, 78.1 top-1 w/ 25M params.
* Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
* Add distilled BiT 50x1 student and 152x2 teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
* NFNets and ResNetV2-BiT models work w/ PyTorch XLA now
* weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered by XLA)
* eps values adjusted, so there will be slight numerical differences, but results should be quite close
* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
* Please report any regressions; this PR touched quite a few models.
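
A minimal sketch of the new `.npz` loading path mentioned above (the local filename below is hypothetical; `create_model` routes `.npz`/`.npy` checkpoints through the model's `load_pretrained()` via `load_checkpoint`, per the helpers change in this PR):

```python
import timm

# hypothetical local path; any AugReg .npz from the gs://vit_models/augreg bucket
# should load the same way as long as the architecture matches the checkpoint
model = timm.create_model('vit_base_patch16_224', checkpoint_path='./augreg_B_16.npz')
model.eval()
```
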
### June 8, 2021
* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.


@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*']
'convit_*', 'levit*', 'visformer*', 'deit*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size):
state_dict = model.state_dict()
cfg = model.default_cfg
classifier = cfg['classifier']
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier name matches default_cfg
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
classifier = cfg['classifier']
if not isinstance(classifier, (tuple, list)):
classifier = classifier,
for c in classifier:
assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs_non_std(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
state_dict = model.state_dict()
cfg = model.default_cfg
input_size = _get_input_size(model=model)
if max(input_size) > 320: # FIXME const
pytest.skip("Fixed input size model > limit.")
input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
if isinstance(outputs, tuple):
outputs = outputs[0]
assert outputs.shape[1] == model.num_features
# test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
if isinstance(outputs, tuple):
outputs = outputs[0]
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
# check classifier name matches default_cfg
classifier = cfg['classifier']
if not isinstance(classifier, (tuple, list)):
classifier = classifier,
for c in classifier:
assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']


@ -74,11 +74,11 @@ default_cfgs = dict(
class ClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to do CA
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add CA and LayerScale
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
mlp_block=Mlp, init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
@ -134,14 +134,14 @@ class LayerScaleBlockClassAttn(nn.Module):
class TalkingHeadAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add layerScale
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
mlp_block=Mlp, init_values=1e-4):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
@ -202,7 +202,7 @@ class Cait(nn.Module):
# with slight modifications to adapt to our cait models
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
global_pool=None,
@ -235,14 +235,14 @@ class Cait(nn.Module):
dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.ModuleList([
block_layers(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
for i in range(depth)])
self.blocks_token_only = nn.ModuleList([
block_layers_token(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block_token_only,
mlp_block=mlp_block_token_only, init_values=init_scale)
@ -270,6 +270,13 @@ class Cait(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
@ -293,7 +300,6 @@ class Cait(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x


@ -335,6 +335,8 @@ class CoaT(nn.Module):
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers
self.out_features = out_features
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
self.num_classes = num_classes
# Patch embeddings.
@ -441,10 +443,10 @@ class CoaT(nn.Module):
# CoaT series: Aggregate features of last three scales for classification.
assert embed_dims[1] == embed_dims[2] == embed_dims[3]
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
self.head = nn.Linear(embed_dims[3], num_classes)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
else:
# CoaT-Lite series: Use feature of last scale for classification.
self.head = nn.Linear(embed_dims[3], num_classes)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# Initialize weights.
trunc_normal_(self.cls_token1, std=.02)
@ -471,7 +473,7 @@ class CoaT(nn.Module):
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def insert_cls(self, x, cls_token):
""" Insert CLS token. """


@ -57,13 +57,13 @@ default_cfgs = {
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):
super().__init__()
self.num_heads = num_heads
self.dim = dim
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.locality_strength = locality_strength
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
@ -141,11 +141,11 @@ class GPSA(nn.Module):
class MHSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
@ -190,19 +190,16 @@ class MHSA(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
super().__init__()
self.norm1 = norm_layer(dim)
self.use_gpsa = use_gpsa
if self.use_gpsa:
self.attn = GPSA(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, **kwargs)
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
else:
self.attn = MHSA(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, **kwargs)
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
@ -219,7 +216,7 @@ class ConViT(nn.Module):
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
super().__init__()
@ -249,13 +246,13 @@ class ConViT(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=True,
locality_strength=locality_strength)
if i < local_up_to_layer else
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=False)
for i in range(depth)])


@ -288,6 +288,8 @@ class DLA(nn.Module):
self.num_features = channels[-1]
self.global_pool, self.fc = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@ -314,6 +316,7 @@ class DLA(nn.Module):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def forward_features(self, x):
x = self.base_layer(x)
@ -331,8 +334,7 @@ class DLA(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
if not self.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
x = self.flatten(x)
return x


@ -237,6 +237,7 @@ class DPN(nn.Module):
# Using 1x1 conv for the FC layer to allow the extra pooling scheme
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def get_classifier(self):
return self.classifier
@ -245,6 +246,7 @@ class DPN(nn.Module):
self.num_classes = num_classes
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def forward_features(self, x):
return self.features(x)
@ -255,8 +257,7 @@ class DPN(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)
if not self.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
x = self.flatten(x)
return x


@ -133,7 +133,7 @@ class GhostBottleneck(nn.Module):
class GhostNet(nn.Module):
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32):
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
super(GhostNet, self).__init__()
# setting of inverted residual blocks
assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
@ -178,9 +178,10 @@ class GhostNet(nn.Module):
# building last several layers
self.num_features = out_chs = 1280
self.global_pool = SelectAdaptivePool2d(pool_type='avg')
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
self.act2 = nn.ReLU(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(out_chs, num_classes)
def get_classifier(self):
@ -190,6 +191,7 @@ class GhostNet(nn.Module):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
@ -204,8 +206,7 @@ class GhostNet(nn.Module):
def forward(self, x):
x = self.forward_features(x)
if not self.global_pool.is_identity():
x = x.view(x.size(0), -1)
x = self.flatten(x)
if self.dropout > 0.:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.classifier(x)


@ -45,6 +45,13 @@ def load_state_dict(checkpoint_path, use_ema=False):
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
model.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict)
@ -477,3 +484,25 @@ def model_parameters(model, exclude_head=False):
return [p for p in model.parameters()][:-2]
else:
return model.parameters()
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
if not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
yield name, module
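
A quick usage sketch for the new `named_apply` helper (import path assumed to be `timm.models.helpers`, matching this diff; the toy network and callback are made up):

```python
import torch.nn as nn
from timm.models.helpers import named_apply  # assumed location per this diff

def report_linear(module, name):
    # named_apply invokes fn(module=..., name=...) for each visited submodule
    if isinstance(module, nn.Linear):
        print(f'{name}: weight {tuple(module.weight.shape)}')

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
named_apply(report_linear, net)  # depth-first over children; prints entries for '0' and '2'
```

The `mlp_mixer.py` diff further down uses the same helper with `functools.partial` to thread `head_bias` into its `_init_weights` function.
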


@ -55,7 +55,7 @@ class FastAdaptiveAvgPool2d(nn.Module):
self.flatten = flatten
def forward(self, x):
return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True)
return x.mean((2, 3), keepdim=not self.flatten)
class AdaptiveAvgMaxPool2d(nn.Module):
@ -82,13 +82,13 @@ class SelectAdaptivePool2d(nn.Module):
def __init__(self, output_size=1, pool_type='fast', flatten=False):
super(SelectAdaptivePool2d, self).__init__()
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = flatten
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
if pool_type == '':
self.pool = nn.Identity() # pass through
elif pool_type == 'fast':
assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(self.flatten)
self.flatten = False
self.pool = FastAdaptiveAvgPool2d(flatten)
self.flatten = nn.Identity()
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':
@ -101,12 +101,11 @@ class SelectAdaptivePool2d(nn.Module):
assert False, 'Invalid pool type: %s' % pool_type
def is_identity(self):
return self.pool_type == ''
return not self.pool_type
def forward(self, x):
x = self.pool(x)
if self.flatten:
x = x.flatten(1)
x = self.flatten(x)
return x
def feat_mult(self):
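
The hunk above replaces the boolean `flatten` flag with an `nn.Flatten` submodule, so the flatten step is an ordinary module in the graph rather than Python control flow in `forward`. A small usage sketch with a made-up feature map size:

```python
import torch
from timm.models.layers import SelectAdaptivePool2d

x = torch.randn(2, 512, 7, 7)  # made-up feature map
pool = SelectAdaptivePool2d(pool_type='avg', flatten=True)
print(pool(x).shape)       # torch.Size([2, 512]) -- pooled to 1x1, then nn.Flatten(1)
print(pool.is_identity())  # False; True only when pool_type is empty/falsy
```
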


@ -20,7 +20,7 @@ def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
return global_pool, num_pooled_features
def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
def _create_fc(num_features, num_classes, use_conv=False):
if num_classes <= 0:
fc = nn.Identity() # pass-through (no classifier)
elif use_conv:
@ -45,11 +45,12 @@ class ClassifierHead(nn.Module):
self.drop_rate = drop_rate
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten_after_fc = use_conv and pool_type
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def forward(self, x):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
x = self.flatten(x)
return x


@ -40,6 +40,12 @@ class GluMlp(nn.Module):
self.fc2 = nn.Linear(hidden_features // 2, out_features)
self.drop = nn.Dropout(drop)
def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
fc1_mid = self.fc1.bias.shape[0] // 2
nn.init.ones_(self.fc1.bias[fc1_mid:])
nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=-1)


@ -27,7 +27,8 @@ class AvgPool2dSame(nn.AvgPool2d):
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
def forward(self, x):
return avg_pool2d_same(
x = pad_same(x, self.kernel_size, self.stride)
return F.avg_pool2d(
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
@ -41,14 +42,15 @@ def max_pool2d_same(
class MaxPool2dSame(nn.MaxPool2d):
""" Tensorflow like 'SAME' wrapper for 2D max pooling
"""
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
def forward(self, x):
return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):


@ -1,3 +1,21 @@
""" Convolution with Weight Standardization (StdConv and ScaledStdConv)
StdConv:
@article{weightstandardization,
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
title = {Weight Standardization},
journal = {arXiv preprint arXiv:1903.10520},
year = {2019},
}
Code: https://github.com/joe-siyuan-qiao/WeightStandardization
ScaledStdConv:
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Hacked together by / copyright Ross Wightman, 2021.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -5,12 +23,6 @@ import torch.nn.functional as F
from .padding import get_padding, get_padding_value, pad_same
def get_weight(module):
std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (module.weight - mean) / (std + module.eps)
return weight
class StdConv2d(nn.Conv2d):
"""Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
@ -18,8 +30,8 @@ class StdConv2d(nn.Conv2d):
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, padding=None, dilation=1,
groups=1, bias=False, eps=1e-5):
self, in_channel, out_channels, kernel_size, stride=1, padding=None,
dilation=1, groups=1, bias=False, eps=1e-6):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
@ -27,13 +39,11 @@ class StdConv2d(nn.Conv2d):
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.eps = eps
def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
return weight
def forward(self, x):
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
@ -44,8 +54,8 @@ class StdConv2dSame(nn.Conv2d):
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', dilation=1,
groups=1, bias=False, eps=1e-5):
self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
dilation=1, groups=1, bias=False, eps=1e-6):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
@ -53,15 +63,13 @@ class StdConv2dSame(nn.Conv2d):
self.same_pad = is_dynamic
self.eps = eps
def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
return weight
def forward(self, x):
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
@ -75,8 +83,8 @@ class ScaledStdConv2d(nn.Conv2d):
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
self, in_channels, out_channels, kernel_size, stride=1, padding=None,
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
@ -84,19 +92,14 @@ class ScaledStdConv2d(nn.Conv2d):
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
def get_weight(self):
if self.use_layernorm:
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
return self.gain * weight
self.eps = eps
def forward(self, x):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class ScaledStdConv2dSame(nn.Conv2d):
@ -109,8 +112,8 @@ class ScaledStdConv2dSame(nn.Conv2d):
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0, use_layernorm=False):
self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
@ -118,26 +121,13 @@ class ScaledStdConv2dSame(nn.Conv2d):
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
self.scale = gamma * self.weight[0].numel() ** -0.5
self.same_pad = is_dynamic
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel
# NOTE an alternate formulation to consider, closer to DeepMind Haiku impl but doesn't seem
# to make much numerical difference (+/- .002 to .004) in top-1 during eval.
# def get_weight(self):
# var, mean = torch.var_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
# scale = torch.rsqrt((self.weight[0].numel() * var).clamp_(self.eps)) * self.gain
# weight = (self.weight - mean) * scale
# return self.gain * weight
def get_weight(self):
if self.use_layernorm:
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
return self.gain * weight
self.eps = eps
def forward(self, x):
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
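
For context on the `F.batch_norm` rewrite above (per the changelog, `std_mean` wasn't being lowered by PyTorch XLA): viewing the weight as a (1, out_channels, fan_in) tensor, batch norm with no running stats computes the same per-filter standardization. A standalone sketch of the equivalence, with arbitrary shapes and an illustrative eps:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(64, 32, 3, 3)  # (out_channels, in_channels, kh, kw), arbitrary
eps = 1e-6

# previous formulation: per-filter std/mean over (in_channels, kh, kw)
std, mean = torch.std_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
# eps goes inside the sqrt here to match batch_norm; the old code added it to std,
# hence the "eps values adjusted" note in the changelog
w_ref = (weight - mean) / torch.sqrt(std ** 2 + eps)

# new formulation: batch_norm over the reshaped weight, no running stats
w_bn = F.batch_norm(
    weight.view(1, weight.shape[0], -1), None, None,
    training=True, momentum=0., eps=eps).reshape_as(weight)

print(torch.allclose(w_ref, w_bn, atol=1e-6))  # True up to float precision
```
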


@ -84,63 +84,33 @@ __all__ = ['Levit']
@register_model
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
def levit_128s(pretrained=False, use_conv=False, **kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
def levit_128(pretrained=False, use_conv=False, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
def levit_192(pretrained=False, use_conv=False, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
def levit_256(pretrained=False, use_conv=False, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
def levit_384(pretrained=False, use_conv=False, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
class ConvNorm(nn.Sequential):
@ -427,6 +397,9 @@ class AttentionSubsample(nn.Module):
class Levit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
NOTE: distillation defaults to True since the pretrained weights use it; this will cause problems
w/ train scripts that don't expect tuple outputs.
"""
def __init__(
@ -447,7 +420,8 @@ class Levit(nn.Module):
attn_act_layer='hard_swish',
distillation=True,
use_conv=False,
drop_path=0):
drop_rate=0.,
drop_path_rate=0.):
super().__init__()
act_layer = get_act_layer(act_layer)
attn_act_layer = get_act_layer(attn_act_layer)
@ -486,7 +460,7 @@ class Levit(nn.Module):
Attention(
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
resolution=resolution, use_conv=use_conv),
drop_path))
drop_path_rate))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
@ -494,7 +468,7 @@ class Levit(nn.Module):
ln_layer(ed, h, resolution=resolution),
act_layer(),
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
), drop_path_rate))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
@ -511,26 +485,45 @@ class Levit(nn.Module):
ln_layer(embed_dim[i + 1], h, resolution=resolution),
act_layer(),
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
), drop_path_rate))
self.blocks = nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
def get_classifier(self):
if self.head_dist is None:
return self.head
else:
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool='', distillation=None):
self.num_classes = num_classes
self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation is not None:
self.distillation = distillation
if self.distillation:
self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
def forward_features(self, x):
x = self.patch_embed(x)
if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
return x
def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
x, x_dist = self.head(x), self.head_dist(x)
if self.training and not torch.jit.is_scripting():
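
Regarding the docstring note above (distillation defaults to True, so a training-mode forward returns a tuple), a hedged usage sketch; averaging the two heads when no teacher is available is just one possible way to consume the output, not something this PR prescribes:

```python
import torch
import timm

model = timm.create_model('levit_128s', pretrained=False)  # distillation=True by default
model.train()
out = model(torch.randn(2, 3, 224, 224))
if isinstance(out, tuple):        # (head, head_dist) outputs in training mode
    logits = sum(out) / len(out)  # e.g. average the heads if no distillation teacher is used
else:
    logits = out                  # eval / scripted path already returns the averaged output
print(logits.shape)               # torch.Size([2, 1000])
```
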


@ -14,8 +14,9 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2
year={2021}
}
Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
Also supporting ResMlp, and a preliminary (not verified) implementation of gMLP
Code: https://github.com/facebookresearch/deit
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
@misc{touvron2021resmlp,
title={ResMLP: Feedforward networks for image classification with data-efficient training},
@ -45,7 +46,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from .registry import register_model
@ -92,13 +93,40 @@ default_cfgs = dict(
),
gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmixer_24_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_12_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_24_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.89),
resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
#url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_36_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_big_24_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_12_distilled_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_24_distilled_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_36_distilled_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_big_24_distilled_224=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_big_24_224_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmlp_ti16_224=_cfg(),
gmlp_s16_224=_cfg(),
@ -172,6 +200,11 @@ class SpatialGatingUnit(nn.Module):
self.norm = norm_layer(gate_dim)
self.proj = nn.Linear(seq_len, seq_len)
def init_weights(self):
# special init for the projection gate, called as override by base model init
nn.init.normal_(self.proj.weight, std=1e-6)
nn.init.ones_(self.proj.bias)
def forward(self, x):
u, v = x.chunk(2, dim=-1)
v = self.norm(v)
@ -208,7 +241,7 @@ class MlpMixer(nn.Module):
in_chans=3,
patch_size=16,
num_blocks=8,
hidden_dim=512,
embed_dim=512,
mlp_ratio=(0.5, 4.0),
block_layer=MixerBlock,
mlp_layer=Mlp,
@ -221,59 +254,95 @@ class MlpMixer(nn.Module):
):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.stem = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim,
norm_layer=norm_layer if stem_norm else None)
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
# FIXME drop_path (stochastic depth scaling rule or all the same?)
self.blocks = nn.Sequential(*[
block_layer(
hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
for _ in range(num_blocks)])
self.norm = norm_layer(hidden_dim)
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, self.num_classes) # zero init
self.init_weights(nlhb=nlhb)
def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0.
for n, m in self.named_modules():
_init_weights(m, n, head_bias=head_bias)
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
def forward(self, x):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.norm(x)
x = x.mean(dim=1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _init_weights(m, n: str, head_bias: float = 0.):
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
""" Mixer weight initialization (trying to match Flax defaults)
"""
if isinstance(m, nn.Linear):
if n.startswith('head'):
nn.init.zeros_(m.weight)
nn.init.constant_(m.bias, head_bias)
elif n.endswith('gate.proj'):
nn.init.normal_(m.weight, std=1e-4)
nn.init.ones_(m.bias)
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
if 'mlp' in n:
nn.init.normal_(m.bias, std=1e-6)
else:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
lecun_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
if flax:
# Flax defaults
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
else:
# like MLP init in vit (my original init)
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
# NOTE if a parent module contains init_weights method, it can override the init of the
# child modules as this will be called in depth-first order.
module.init_weights()
def checkpoint_filter_fn(state_dict, model):
""" Remap checkpoints if needed """
if 'patch_embed.proj.weight' in state_dict:
# Remap FB ResMlp models -> timm
out_dict = {}
for k, v in state_dict.items():
k = k.replace('patch_embed.', 'stem.')
k = k.replace('attn.', 'linear_tokens.')
k = k.replace('mlp.', 'mlp_channels.')
k = k.replace('gamma_', 'ls')
if k.endswith('.alpha') or k.endswith('.beta'):
v = v.reshape(1, 1, -1)
out_dict[k] = v
return out_dict
return state_dict
def _create_mixer(variant, pretrained=False, **kwargs):
@ -283,6 +352,7 @@ def _create_mixer(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
MlpMixer, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@ -292,7 +362,7 @@ def mixer_s32_224(pretrained=False, **kwargs):
""" Mixer-S/32 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs)
model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
return model
@ -302,7 +372,7 @@ def mixer_s16_224(pretrained=False, **kwargs):
""" Mixer-S/16 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs)
model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
return model
@ -312,7 +382,7 @@ def mixer_b32_224(pretrained=False, **kwargs):
""" Mixer-B/32 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs)
model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
return model
@ -322,7 +392,7 @@ def mixer_b16_224(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
return model
@ -332,7 +402,7 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
return model
@ -342,7 +412,7 @@ def mixer_l32_224(pretrained=False, **kwargs):
""" Mixer-L/32 224x224.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs)
model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
return model
@ -352,7 +422,7 @@ def mixer_l16_224(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
return model
@ -362,35 +432,38 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_b16_224_miil(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
return model
@register_model
def gmixer_12_224(pretrained=False, **kwargs):
""" Glu-Mixer-12 224x224 (short & fat)
""" Glu-Mixer-12 224x224
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
"""
model_args = dict(
patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0),
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
return model
@ -398,11 +471,11 @@ def gmixer_12_224(pretrained=False, **kwargs):
@register_model
def gmixer_24_224(pretrained=False, **kwargs):
""" Glu-Mixer-24 224x224 (tall & slim)
""" Glu-Mixer-24 224x224
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
"""
model_args = dict(
patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0),
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
return model
@ -414,7 +487,7 @@ def resmlp_12_224(pretrained=False, **kwargs):
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
return model
@ -425,7 +498,8 @@ def resmlp_24_224(pretrained=False, **kwargs):
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
return model
@ -436,18 +510,90 @@ def resmlp_36_224(pretrained=False, **kwargs):
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_big_24_224(pretrained=False, **kwargs):
""" ResMLP-B-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_12_distilled_224(pretrained=False, **kwargs):
""" ResMLP-12
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_24_distilled_224(pretrained=False, **kwargs):
""" ResMLP-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_36_distilled_224(pretrained=False, **kwargs):
""" ResMLP-36
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
""" ResMLP-B-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
""" ResMLP-B-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args)
return model
@register_model
def gmlp_ti16_224(pretrained=False, **kwargs):
""" gMLP-Tiny
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
return model
@ -459,7 +605,7 @@ def gmlp_s16_224(pretrained=False, **kwargs):
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
return model
@ -471,7 +617,7 @@ def gmlp_b16_224(pretrained=False, **kwargs):
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
return model


@ -119,6 +119,7 @@ class MobileNetV3(nn.Module):
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
efficientnet_init_weights(self)
@ -137,6 +138,7 @@ class MobileNetV3(nn.Module):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
@ -151,8 +153,7 @@ class MobileNetV3(nn.Module):
def forward(self, x):
x = self.forward_features(x)
if not self.global_pool.is_identity():
x = x.flatten(1)
x = self.flatten(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
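With the head flatten now stored as a module, disabling global pooling yields unflattened conv features instead of relying on the old in-forward check. A small sketch of the intended behaviour, assuming the `mobilenetv3_large_100` entrypoint (not part of this hunk) and its usual 1280 head channels:

import timm
import torch

x = torch.randn(2, 3, 224, 224)

# default: pooled features, flattened by self.flatten before the (here removed) classifier
m = timm.create_model('mobilenetv3_large_100', num_classes=0).eval()
print(m(x).shape)  # e.g. torch.Size([2, 1280])

# pooling disabled: self.flatten is nn.Identity(), so spatial dims are kept
m = timm.create_model('mobilenetv3_large_100', num_classes=0, global_pool='').eval()
print(m(x).shape)  # e.g. torch.Size([2, 1280, 7, 7])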

View File

@ -111,11 +111,11 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
eca_nfnet_l2=_dcfg(
url='',
pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
eca_nfnet_l3=_dcfg(
url='',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0),
nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
@ -166,6 +166,7 @@ class NfCfg:
extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
gamma_in_act: bool = False
same_padding: bool = False
std_conv_eps: float = 1e-5
skipinit: bool = False # disabled by default, non-trivial performance impact
zero_init_fc: bool = False
act_layer: str = 'silu'
@ -209,6 +210,7 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
return cfg
model_cfgs = dict(
# NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
@ -482,10 +484,10 @@ class NormFreeNet(nn.Module):
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(conv_layer, eps=1e-4) # DM weights better with higher eps
conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
else:
act_layer = get_act_layer(cfg.act_layer)
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
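Since `eca_nfnet_l2` now ships weights and larger default resolutions, the updated cfg values can be checked straight off a constructed model. A quick sketch, using `pretrained=False` to avoid any download:

import timm

model = timm.create_model('eca_nfnet_l2', pretrained=False)
cfg = model.default_cfg
# per the cfg above: (3, 320, 320) for train, (3, 384, 384) for test, crop_pct 1.0
print(cfg['input_size'], cfg['test_input_size'], cfg['crop_pct'])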

View File

@ -186,12 +186,13 @@ class PoolingVisionTransformer(nn.Module):
]
self.transformers = SequentialTuple(*transformers)
self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
self.embed_dim = base_dims[-1] * heads[-1]
self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
# Classifier head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
if num_classes > 0 and distilled else nn.Identity()
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@ -207,13 +208,16 @@ class PoolingVisionTransformer(nn.Module):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
if self.head_dist is not None:
return self.head, self.head_dist
else:
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
if num_classes > 0 and self.num_tokens == 2 else nn.Identity()
if self.head_dist is not None:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
@ -221,19 +225,21 @@ class PoolingVisionTransformer(nn.Module):
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x, cls_tokens = self.transformers((x, cls_tokens))
cls_tokens = self.norm(cls_tokens)
return cls_tokens
if self.head_dist is not None:
return cls_tokens[:, 0], cls_tokens[:, 1]
else:
return cls_tokens[:, 0]
def forward(self, x):
x = self.forward_features(x)
x_cls = self.head(x[:, 0])
if self.num_tokens > 1:
x_dist = self.head_dist(x[:, 1])
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
return x_cls, x_dist
return x, x_dist
else:
return (x_cls + x_dist) / 2
return (x + x_dist) / 2
else:
return x_cls
return self.head(x)
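With the distilled-head rework, `get_classifier` returns a `(head, head_dist)` pair for distilled PiT models and eval-mode inference averages the two heads. A brief sketch, assuming a distilled entrypoint such as `pit_ti_distilled_224` (not shown in this hunk):

import timm
import torch

model = timm.create_model('pit_ti_distilled_224', pretrained=False)
model.eval()

heads = model.get_classifier()
print(isinstance(heads, tuple))  # True for distilled models: (head, head_dist)

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # eval mode averages the two head outputs -> torch.Size([1, 1000])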
def checkpoint_filter_fn(state_dict, model):

View File

@ -65,11 +65,18 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in their name, from the 'resnet' module
"""
if module:
models = list(_module_to_models[module])
all_models = list(_module_to_models[module])
else:
models = _model_entrypoints.keys()
all_models = _model_entrypoints.keys()
if filter:
models = fnmatch.filter(models, filter) # include these models
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models):
models = set(models).union(include_models)
else:
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters]
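The `filter` argument now also accepts a tuple/list of wildcard patterns, with results taken as the union of all matches. A short usage sketch:

import timm

# single pattern, as before
print(len(timm.list_models('resnetv2_*bitm')))

# new: multiple patterns, union of matches, optionally restricted to pretrained entries
names = timm.list_models(['resnetv2_*bit*', 'eca_nfnet_*'], pretrained=True)
print(names[:5])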

View File

@ -638,12 +638,15 @@ class ResNet(nn.Module):
self.num_features = 512 * block.expansion
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.init_weights(zero_init_last_bn=zero_init_last_bn)
def init_weights(self, zero_init_last_bn=True):
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):

View File

@ -11,6 +11,7 @@ https://github.com/google-research/vision_transformer
Thanks to the Google team for the above two repositories and associated papers:
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
"""
@ -35,16 +36,16 @@ import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
'crop_pct': 1.0, 'interpolation': 'bilinear',
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
@ -54,17 +55,23 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# pretrained on imagenet21k, finetuned on imagenet1k
'resnetv2_50x1_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
'resnetv2_50x3_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
'resnetv2_101x1_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
'resnetv2_101x3_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
'resnetv2_152x2_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
'resnetv2_152x4_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480?
# trained on imagenet-21k
'resnetv2_50x1_bitm_in21k': _cfg(
@ -86,20 +93,20 @@ default_cfgs = {
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
num_classes=21843),
'resnetv2_50x1_bit_distilled': _cfg(
url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
interpolation='bicubic'),
'resnetv2_152x2_bit_teacher': _cfg(
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
interpolation='bicubic'),
'resnetv2_152x2_bit_teacher_384': _cfg(
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
# trained on imagenet-1k, NOTE: not an overly interesting set of weights, leaving disabled for now
# 'resnetv2_50x1_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
# 'resnetv2_50x3_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
# 'resnetv2_101x1_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
# 'resnetv2_101x3_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
# 'resnetv2_152x2_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
# 'resnetv2_152x4_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
'resnetv2_50': _cfg(
interpolation='bicubic'),
'resnetv2_50d': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
}
@ -111,13 +118,6 @@ def make_div(v, divisor=8):
return new_v
def tf2th(conv_weights):
"""Possibly convert HWIO to OIHW."""
if conv_weights.ndim == 4:
conv_weights = conv_weights.transpose([3, 2, 0, 1])
return torch.from_numpy(conv_weights)
class PreActBottleneck(nn.Module):
"""Pre-activation (v2) bottleneck block.
@ -152,6 +152,9 @@ class PreActBottleneck(nn.Module):
self.conv3 = conv_layer(mid_chs, out_chs, 1)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last_bn(self):
nn.init.zeros_(self.norm3.weight)
def forward(self, x):
x_preact = self.norm1(x)
@ -198,6 +201,9 @@ class Bottleneck(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.act3 = act_layer(inplace=True)
def zero_init_last_bn(self):
nn.init.zeros_(self.norm3.weight)
def forward(self, x):
# shortcut branch
shortcut = x
@ -285,14 +291,17 @@ def create_resnetv2_stem(
# A 3 deep 3x3 conv stack as in ResNet V1D models
mid_chs = out_chs // 2
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
stem['norm1'] = norm_layer(mid_chs)
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
stem['norm2'] = norm_layer(mid_chs)
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
if not preact:
stem['norm3'] = norm_layer(out_chs)
else:
# The usual 7x7 stem conv
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
if not preact:
stem['norm'] = norm_layer(out_chs)
if not preact:
stem['norm'] = norm_layer(out_chs)
if 'fixed' in stem_type:
# 'fixed' SAME padding approximation that is used in BiT models
@ -312,11 +321,12 @@ class ResNetV2(nn.Module):
"""Implementation of Pre-activation (v2) ResNet mode.
"""
def __init__(self, layers, channels=(256, 512, 1024, 2048),
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0.):
def __init__(
self, layers, channels=(256, 512, 1024, 2048),
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -354,12 +364,14 @@ class ResNetV2(nn.Module):
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
for n, m in self.named_modules():
if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
self.init_weights(zero_init_last_bn=zero_init_last_bn)
def init_weights(self, zero_init_last_bn=True):
named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
_load_weights(self, checkpoint_path, prefix)
def get_classifier(self):
return self.head.fc
@ -378,41 +390,59 @@ class ResNetV2(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
if not self.head.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
import numpy as np
weights = np.load(checkpoint_path)
with torch.no_grad():
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
if self.stem.conv.weight.shape[1] == 1:
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
# FIXME handle > 3 in_chans?
else:
self.stem.conv.weight.copy_(stem_conv_w)
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
for i, (sname, stage) in enumerate(self.stages.named_children()):
for j, (bname, block) in enumerate(stage.blocks.named_children()):
convname = 'standardized_conv2d'
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
if block.downsample is not None:
w = weights[f'{block_prefix}a/proj/{convname}/kernel']
block.downsample.conv.weight.copy_(tf2th(w))
def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
module.zero_init_last_bn()
@torch.no_grad()
def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
import numpy as np
def t2p(conv_weights):
"""Possibly convert HWIO to OIHW."""
if conv_weights.ndim == 4:
conv_weights = conv_weights.transpose([3, 2, 0, 1])
return torch.from_numpy(conv_weights)
weights = np.load(checkpoint_path)
stem_conv_w = adapt_input_conv(
model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
model.stem.conv.weight.copy_(stem_conv_w)
model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
for i, (sname, stage) in enumerate(model.stages.named_children()):
for j, (bname, block) in enumerate(stage.blocks.named_children()):
cname = 'standardized_conv2d'
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
if block.downsample is not None:
w = weights[f'{block_prefix}a/proj/{cname}/kernel']
block.downsample.conv.weight.copy_(t2p(w))
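Because the old in-class loader was moved into this standalone `_load_weights` (called from `load_pretrained`), a BiT `.npz` checkpoint can be loaded into a freshly constructed model in one call. A sketch, where the local file path is only a placeholder for a downloaded copy of one of the URLs above:

import timm

model = timm.create_model('resnetv2_50x1_bitm', pretrained=False)
# 'BiT-M-R50x1-ILSVRC2012.npz' stands in for a locally downloaded checkpoint
model.load_pretrained('BiT-M-R50x1-ILSVRC2012.npz', prefix='resnet/')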
def _create_resnetv2(variant, pretrained=False, **kwargs):
@ -425,130 +455,126 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
**kwargs)
def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
return _create_resnetv2(
variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
@register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50x1_bitm', pretrained=pretrained,
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
return _create_resnetv2_bit(
'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
@register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50x3_bitm', pretrained=pretrained,
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
return _create_resnetv2_bit(
'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
@register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x1_bitm', pretrained=pretrained,
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
return _create_resnetv2_bit(
'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
@register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x3_bitm', pretrained=pretrained,
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
return _create_resnetv2_bit(
'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
@register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152x2_bitm', pretrained=pretrained,
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
return _create_resnetv2_bit(
'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
@register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152x4_bitm', pretrained=pretrained,
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
return _create_resnetv2_bit(
'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
@register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
return _create_resnetv2_bit(
'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
layers=[3, 4, 6, 3], width_factor=1, **kwargs)
@register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
return _create_resnetv2_bit(
'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
layers=[3, 4, 6, 3], width_factor=3, **kwargs)
@register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
layers=[3, 4, 23, 3], width_factor=1, **kwargs)
@register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
return _create_resnetv2_bit(
'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
layers=[3, 4, 23, 3], width_factor=3, **kwargs)
@register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
return _create_resnetv2_bit(
'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
layers=[3, 8, 36, 3], width_factor=2, **kwargs)
@register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2(
return _create_resnetv2_bit(
'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
layers=[3, 8, 36, 3], width_factor=4, **kwargs)
# NOTE the 'S' versions of the model weights aren't as interesting as the original 21k or 21k->1k transfer 'M' weights.
@register_model
def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
""" ResNetV2-50x1-BiT Distilled
Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
"""
return _create_resnetv2_bit(
'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
# @register_model
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x1_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x3_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x1_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x3_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x2_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x4_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
#
@register_model
def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
""" ResNetV2-152x2-BiT Teacher
Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
"""
return _create_resnetv2_bit(
'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
@register_model
def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
""" ResNetV2-152xx-BiT Teacher @ 384x384
Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
"""
return _create_resnetv2_bit(
'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
@register_model
def resnetv2_50(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
stem_type='deep', avg_down=True, **kwargs)

View File

@ -126,19 +126,18 @@ class WindowAttention(nn.Module):
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
@ -210,7 +209,6 @@ class SwinTransformerBlock(nn.Module):
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
@ -219,7 +217,7 @@ class SwinTransformerBlock(nn.Module):
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
@ -236,8 +234,8 @@ class SwinTransformerBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
@ -369,7 +367,6 @@ class BasicLayer(nn.Module):
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
@ -379,7 +376,7 @@ class BasicLayer(nn.Module):
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
@ -390,14 +387,11 @@ class BasicLayer(nn.Module):
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
@ -436,7 +430,6 @@ class SwinTransformer(nn.Module):
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
@ -448,7 +441,7 @@ class SwinTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, weight_init='', **kwargs):
@ -491,8 +484,9 @@ class SwinTransformer(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
@ -520,6 +514,13 @@ class SwinTransformer(nn.Module):
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
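The added `get_classifier` / `reset_classifier` bring Swin in line with the common timm classifier interface, so the head can be inspected and swapped after creation. A small sketch, assuming the `swin_tiny_patch4_window7_224` entrypoint:

import timm
import torch

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False).eval()
print(type(model.get_classifier()))  # nn.Linear head by default

model.reset_classifier(num_classes=5)  # swap in a 5-class head
with torch.no_grad():
    print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 5])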
def forward_features(self, x):
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:

View File

@ -278,6 +278,8 @@ class Twins(nn.Module):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
img_size = to_2tuple(img_size)
prev_chs = in_chans
@ -303,10 +305,10 @@ class Twins(nn.Module):
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
self.norm = norm_layer(embed_dims[-1])
self.norm = norm_layer(self.num_features)
# classification head
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# init weights
self.apply(self._init_weights)
@ -320,7 +322,7 @@ class Twins(nn.Module):
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def _init_weights(self, m):
if isinstance(m, nn.Linear):

View File

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
from .registry import register_model
@ -140,14 +140,14 @@ class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
super().__init__()
img_size = to_2tuple(img_size)
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.embed_dim = embed_dim
self.init_channels = init_channels
self.img_size = img_size
self.vit_stem = vit_stem
self.pool = pool
self.conv_init = conv_init
if isinstance(depth, (list, tuple)):
self.stage_num1, self.stage_num2, self.stage_num3 = depth
@ -164,31 +164,31 @@ class Visformer(nn.Module):
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 16
img_size = [x // 16 for x in img_size]
else:
if self.init_channels is None:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 8
img_size = [x // 8 for x in img_size]
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.init_channels),
nn.ReLU(inplace=True)
)
img_size //= 2
img_size = [x // 2 for x in img_size]
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 4
img_size = [x // 4 for x in img_size]
if self.pos_embed:
if self.vit_stem:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
self.pos_drop = nn.Dropout(p=drop_rate)
self.stage1 = nn.ModuleList([
Block(
@ -199,14 +199,14 @@ class Visformer(nn.Module):
for i in range(self.stage_num1)
])
#stage2
# stage2
if not self.vit_stem:
self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 2
img_size = [x // 2 for x in img_size]
if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
self.stage2 = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@ -221,9 +221,9 @@ class Visformer(nn.Module):
self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size //= 2
img_size = [x // 2 for x in img_size]
if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
self.stage3 = nn.ModuleList([
Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@ -234,11 +234,10 @@ class Visformer(nn.Module):
])
# head
if self.pool:
self.global_pooling = nn.AdaptiveAvgPool2d(1)
head_dim = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(head_dim)
self.head = nn.Linear(head_dim, num_classes)
self.num_features = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(self.num_features)
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.head = nn.Linear(self.num_features, num_classes)
# weights init
if self.pos_embed:
@ -267,7 +266,14 @@ class Visformer(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0.)
def forward(self, x):
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
if self.stem is not None:
x = self.stem(x)
@ -297,14 +303,13 @@ class Visformer(nn.Module):
for b in self.stage3:
x = b(x)
# head
x = self.norm(x)
if self.pool:
x = self.global_pooling(x)
else:
x = x[:, :, 0, 0]
return x
x = self.head(x.view(x.size(0), -1))
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x)
x = self.head(x)
return x
@ -321,7 +326,7 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
@register_model
def visformer_tiny(pretrained=False, **kwargs):
model_cfg = dict(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
@ -331,7 +336,7 @@ def visformer_tiny(pretrained=False, **kwargs):
@register_model
def visformer_small(pretrained=False, **kwargs):
model_cfg = dict(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)

View File

@ -1,7 +1,12 @@
""" Vision Transformer (ViT) in PyTorch
A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
A PyTorch implement of Vision Transformers as described in:
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.10270
The official jax code is released and available at https://github.com/google-research/vision_transformer
@ -15,7 +20,7 @@ for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
import logging
@ -27,8 +32,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model
@ -40,86 +45,118 @@ def _cfg(url='', **kwargs):
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# patch models (my experiments)
# patch models (weights from official Google JAX impl)
'vit_tiny_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_tiny_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
),
# patch models (weights ported from official Google JAX impl)
'vit_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_base_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_base_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_base_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
),
'vit_large_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_large_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
# patch models, imagenet21k (weights ported from official Google JAX impl)
'vit_base_patch16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_large_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
num_classes=21843),
'vit_large_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
num_classes=21843),
'vit_huge_patch14_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
num_classes=21843),
# deit models (FB weights)
'vit_deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
'vit_deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
'vit_deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
'vit_deit_base_patch16_384': _cfg(
'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_deit_tiny_distilled_patch16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
'deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
classifier=('head', 'head_dist')),
'vit_deit_small_distilled_patch16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_384': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
# ViT ImageNet-21K-P pretraining
# ViT ImageNet-21K-P pretraining by MIIL
'vit_base_patch16_224_miil_in21k': _cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
@ -133,11 +170,11 @@ default_cfgs = {
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
@ -161,12 +198,11 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
@ -190,7 +226,7 @@ class VisionTransformer(nn.Module):
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init=''):
"""
@ -204,7 +240,6 @@ class VisionTransformer(nn.Module):
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
@ -233,8 +268,8 @@ class VisionTransformer(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
@ -254,16 +289,17 @@ class VisionTransformer(nn.Module):
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
# Weight init
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02)
if weight_init.startswith('jax'):
if mode.startswith('jax'):
# leave cls token as zeros to match jax impl
for n, m in self.named_modules():
_init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
else:
trunc_normal_(self.cls_token, std=.02)
self.apply(_init_vit_weights)
@ -272,6 +308,10 @@ class VisionTransformer(nn.Module):
# this fn left here for compat with downstream users
_init_vit_weights(m)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
_load_weights(self, checkpoint_path, prefix)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'}
@ -317,39 +357,116 @@ class VisionTransformer(nn.Module):
return x
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
""" ViT weight initialization
* When called without name, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ a valid name (module name) and jax_impl=True, will (hopefully) match JAX impl
"""
if isinstance(m, nn.Linear):
if n.startswith('head'):
nn.init.zeros_(m.weight)
nn.init.constant_(m.bias, head_bias)
elif n.startswith('pre_logits'):
lecun_normal_(m.weight)
nn.init.zeros_(m.bias)
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
lecun_normal_(module.weight)
nn.init.zeros_(module.bias)
else:
if jax_impl:
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
if 'mlp' in n:
nn.init.normal_(m.bias, std=1e-6)
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(m.bias)
nn.init.zeros_(module.bias)
else:
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif jax_impl and isinstance(m, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2d):
# NOTE conv was left to pytorch default in my original init
lecun_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
if hasattr(model.patch_embed, 'backbone'):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
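
NOTE (editor): a hedged usage sketch for the loader above; the local .npz filename is hypothetical, and the model architecture / classifier size must match the checkpoint.

```python
import timm

# Load an ImageNet-21k AugReg .npz directly into a matching model.
# num_classes must match the checkpoint head (21843 for i21k weights).
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=21843)
model.load_pretrained('B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz')  # hypothetical path
```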
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
@ -413,34 +530,64 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
default_cfg=default_cfg,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in default_cfg['url'],
**kwargs)
return model
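
NOTE (editor): the `pretrained_custom_load` flag above routes any default cfg whose URL ends in `.npz` through `_load_weights` instead of the usual state_dict path; a small sketch, assuming network access to the weight bucket.

```python
import timm

# pretrained=True on a model whose default_cfg URL is an .npz file triggers the custom npz loader.
model = timm.create_model('vit_tiny_patch16_384', pretrained=True)
```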
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
""" My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
NOTE:
* this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
* this model does not have a bias for QKV (unlike the official ViT and DeiT models)
def vit_tiny_patch16_224(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16)
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
if pretrained:
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_patch16_384(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16) @ 384x384.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch32_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32)
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch32_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32) at 384x384.
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@ -453,16 +600,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
@ -474,12 +611,22 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@ -492,16 +639,6 @@ def vit_large_patch32_224(pretrained=False, **kwargs):
return model
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
@ -513,13 +650,52 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
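
NOTE (editor): the in21k variants registered here default to the 21843-class head from their cfgs; a quick sketch of re-heading one for fine-tuning (the 10-class head is just an example).

```python
import timm

# Reset the 21k classifier to a smaller head for fine-tuning on a downstream dataset.
model = timm.create_model('vit_small_patch16_224_in21k', pretrained=False, num_classes=10)
```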
@ -535,13 +711,13 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -556,6 +732,17 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
return model
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
@ -569,86 +756,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
def deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
def deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
def deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
def deit_base_patch16_384(pretrained=False, **kwargs):
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer(
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer(
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
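
NOTE (editor): with the `vit_deit_*` to `deit_*` rename, existing scripts only need the new names; a small sanity-check sketch, assuming timm's usual averaging of the two distilled heads at inference.

```python
import torch
import timm

# Renamed DeiT entry point; the distilled variant returns a single averaged logit tensor in eval mode.
model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])
```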

View File

@ -1,13 +1,17 @@
""" Hybrid Vision Transformer (ViT) in PyTorch
A PyTorch implementation of the Hybrid Vision Transformers as described in
A PyTorch implementation of the Hybrid Vision Transformers as described in:
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929
NOTE This relies on code in vision_transformer.py. The hybrid model definitions were moved here to
keep file sizes sane.
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.10270
Hacked together by / Copyright 2020 Ross Wightman
NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane.
Hacked together by / Copyright 2021 Ross Wightman
"""
from copy import deepcopy
from functools import partial
@ -35,32 +39,61 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
# hybrid in-1k models (weights ported from official JAX impl)
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r26_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
),
'vit_small_r26_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(),
'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
),
'vit_large_r50_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0
),
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
'vit_tiny_r_s16_p8_224': _cfg(),
'vit_small_r_s16_p8_224': _cfg(),
'vit_small_r20_s16_p2_224': _cfg(),
'vit_small_r20_s16_224': _cfg(),
'vit_small_r26_s32_224': _cfg(),
'vit_base_r20_s16_224': _cfg(),
'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(),
'vit_large_r50_s32_224': _cfg(),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
'vit_small_r26_s32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'vit_small_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_small_resnet50d_s16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
}
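
NOTE (editor): the new hybrid cfg entries above carry `input_size`, `crop_pct` and `first_conv`; a sketch of how those fields flow into timm's data helpers (model choice is illustrative).

```python
import timm
from timm.data import resolve_data_config, create_transform

# The 384px AugReg hybrid picks up its input_size / crop_pct from the default_cfg above.
model = timm.create_model('vit_small_r26_s32_384', pretrained=False)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config['input_size'], config['crop_pct'])  # (3, 384, 384) 1.0
```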
@ -95,7 +128,8 @@ class HybridEmbed(nn.Module):
else:
feature_dim = self.backbone.num_features
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
@ -116,12 +150,8 @@ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwa
def _resnetv2(layers=(3, 4, 9), **kwargs):
""" ResNet-V2 backbone helper"""
padding_same = kwargs.get('padding_same', True)
if padding_same:
stem_type = 'same'
conv_layer = StdConv2dSame
else:
stem_type = ''
conv_layer = StdConv2d
stem_type = 'same' if padding_same else ''
conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
if len(layers):
backbone = ResNetV2(
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
@ -132,42 +162,6 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
return backbone
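
NOTE (editor): the backbone helper above now passes an explicit `eps` to the standardized conv layers; a minimal construction sketch of that choice (channel sizes are arbitrary).

```python
from functools import partial
from timm.models.layers import StdConv2dSame

# Same construction as _resnetv2 above: 'same'-padded standardized conv with an explicit eps.
conv_layer = partial(StdConv2dSame, eps=1e-8)
conv = conv_layer(3, 64, kernel_size=7, stride=2)
```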
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# NOTE this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# NOTE this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
@ -180,36 +174,13 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
@register_model
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2((2, 4), **kwargs)
model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-S/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@ -225,13 +196,13 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
@register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid.
def vit_small_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@ -257,17 +228,97 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.

View File

@ -1 +1 @@
__version__ = '0.4.11'
__version__ = '0.4.12'