Benchmark models listed in txt file. Add more hybrid vit variants for testing
parent
2db2d87ff7
commit
0706d05d52
13
benchmark.py
13
benchmark.py
|
@ -45,6 +45,8 @@ _logger = logging.getLogger('validate')
|
|||
parser = argparse.ArgumentParser(description='PyTorch Benchmark')
|
||||
|
||||
# benchmark specific args
|
||||
parser.add_argument('--model-list', metavar='NAME', default='',
|
||||
help='txt file based list of model names to benchmark')
|
||||
parser.add_argument('--bench', default='both', type=str,
|
||||
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'inference'")
|
||||
parser.add_argument('--detail', action='store_true', default=False,
|
||||
|
@ -357,7 +359,7 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
|
|||
except RuntimeError as e:
|
||||
torch.cuda.empty_cache()
|
||||
batch_size = decay_batch_exp(batch_size)
|
||||
print(f'Reducing batch size to {batch_size}')
|
||||
print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
|
||||
return results
|
||||
|
||||
|
||||
|
@ -413,7 +415,12 @@ def main():
|
|||
model_cfgs = []
|
||||
model_names = []
|
||||
|
||||
if args.model == 'all':
|
||||
if args.model_list:
|
||||
args.model = ''
|
||||
with open(args.model_list) as f:
|
||||
model_names = [line.rstrip() for line in f]
|
||||
model_cfgs = [(n, None) for n in model_names]
|
||||
elif args.model == 'all':
|
||||
# validate all models in a list of names with pretrained checkpoints
|
||||
args.pretrained = True
|
||||
model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
|
||||
|
@ -429,6 +436,8 @@ def main():
|
|||
results = []
|
||||
try:
|
||||
for m, _ in model_cfgs:
|
||||
if not m:
|
||||
continue
|
||||
args.model = m
|
||||
r = benchmark(args)
|
||||
results.append(r)
|
||||
|
|
|
@ -103,48 +103,90 @@ default_cfgs = {
|
|||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
# hybrid in-1k models (weights ported from official Google JAX impl where they exist)
|
||||
'vit_small_r_s16_p8_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
'vit_tiny_r_s16_p8_224': _cfg(
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_tiny_r_s16_p8_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_tiny_r_s16_p8_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_small_r_s16_p8_224': _cfg(
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r_s16_p8_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r_s16_p8_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_small_r20_s16_p2_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r20_s16_p2_224_in21k': _cfg(
|
||||
inum_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r20_s16_p2_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
|
||||
'vit_small_r20_s16_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r20_s16_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r20_s16_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_small_r26_s32_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r26_s32_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_small_r26_s32_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_base_r20_s16_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_base_r20_s16_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_base_r20_s16_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_base_r26_s32_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_base_r26_s32_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_base_r26_s32_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_base_r50_s16_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_base_r50_s16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
'vit_large_r50_s32_224': _cfg(
|
||||
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_large_r50_s32_224_in21k': _cfg(
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_large_r50_s32_384': _cfg(
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
||||
|
@ -159,8 +201,19 @@ default_cfgs = {
|
|||
# deit models (FB weights)
|
||||
'vit_deit_tiny_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
||||
'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843),
|
||||
'vit_deit_tiny_patch16_224_in21k_norep': _cfg(num_classes=21843),
|
||||
'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)),
|
||||
|
||||
'vit_deit_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
||||
'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843),
|
||||
'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)),
|
||||
|
||||
'vit_deit_small_patch32_224': _cfg(),
|
||||
'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843),
|
||||
'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)),
|
||||
|
||||
'vit_deit_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
|
||||
'vit_deit_base_patch16_384': _cfg(
|
||||
|
@ -728,7 +781,29 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
|||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(
|
||||
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs)
|
||||
model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
|
||||
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(
|
||||
patch_size=8, embed_dim=192, depth=12, num_heads=3, representation_size=192, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_tiny_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
|
||||
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(
|
||||
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -740,6 +815,29 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
|
|||
model_kwargs = dict(
|
||||
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r_s16_p8_224_in21k(pretrained=False, **kwargs):
|
||||
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(
|
||||
patch_size=8, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r_s16_p8_384(pretrained=False, **kwargs):
|
||||
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||
"""
|
||||
backbone = _resnetv2(layers=(), **kwargs)
|
||||
model_kwargs = dict(
|
||||
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -754,6 +852,17 @@ def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r20_s16_p2_224_in21k(pretrained=False, **kwargs):
|
||||
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
|
||||
"""
|
||||
backbone = _resnetv2((2, 4), **kwargs)
|
||||
model_kwargs = dict(
|
||||
patch_size=2, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r20_s16_p2_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
|
||||
""" R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384.
|
||||
|
@ -775,6 +884,16 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r20_s16_224_in21k(pretrained=False, **kwargs):
|
||||
""" R20+ViT-S/S16 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r20_s16_384(pretrained=False, **kwargs):
|
||||
""" R20+ViT-S/S16 hybrid @ 384x384.
|
||||
|
@ -795,6 +914,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
|
||||
""" R26+ViT-S/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(
|
||||
embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_small_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_r26_s32_384(pretrained=False, **kwargs):
|
||||
""" R26+ViT-S/S32 hybrid @ 384x384.
|
||||
|
@ -810,12 +940,22 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs):
|
|||
""" R20+ViT-B/S16 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(
|
||||
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r20_s16_224_in21k(pretrained=False, **kwargs):
|
||||
""" R20+ViT-B/S16 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(
|
||||
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r20_s16_384(pretrained=False, **kwargs):
|
||||
""" R20+ViT-B/S16 hybrid.
|
||||
|
@ -836,6 +976,27 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r26_s32_224_in21k(pretrained=False, **kwargs):
|
||||
""" R26+ViT-B/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(
|
||||
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r26_s32_384(pretrained=False, **kwargs):
|
||||
""" R26+ViT-B/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
||||
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_r50_s16_224(pretrained=False, **kwargs):
|
||||
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
||||
|
@ -867,6 +1028,17 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
|
||||
""" R50+ViT-L/S32 hybrid.
|
||||
"""
|
||||
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
||||
model_kwargs = dict(
|
||||
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_r50_s32_384(pretrained=False, **kwargs):
|
||||
""" R50+ViT-L/S32 hybrid.
|
||||
|
@ -927,6 +1099,31 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_tiny_patch16_224_in21k_norep(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny model"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k_norep', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny model"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_tiny_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-tiny model"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
|
@ -937,6 +1134,48 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs):
|
||||
""" DeiT-small """
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-small """
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch32_224(pretrained=False, **kwargs):
|
||||
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs):
|
||||
""" DeiT-small """
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_small_patch32_384(pretrained=False, **kwargs):
|
||||
""" DeiT-small """
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||
model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
||||
|
|
Loading…
Reference in New Issue