mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add small vision transformer weights. 77.42 top-1.
This commit is contained in:
parent
ccfb5751ab
commit
d4db9e7977
@ -29,7 +29,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from .helpers import build_model_with_cfg
|
from .helpers import load_pretrained
|
||||||
from .layers import DropPath, to_2tuple, trunc_normal_
|
from .layers import DropPath, to_2tuple, trunc_normal_
|
||||||
from .resnet import resnet26d, resnet50d
|
from .resnet import resnet26d, resnet50d
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
|
|||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
# patch models
|
# patch models
|
||||||
'vit_small_patch16_224': _cfg(),
|
'vit_small_patch16_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
||||||
|
),
|
||||||
'vit_base_patch16_224': _cfg(),
|
'vit_base_patch16_224': _cfg(),
|
||||||
'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
|
'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
|
||||||
'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
|
'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
|
||||||
@ -271,6 +273,9 @@ class VisionTransformer(nn.Module):
|
|||||||
def vit_small_patch16_224(pretrained=False, **kwargs):
|
def vit_small_patch16_224(pretrained=False, **kwargs):
|
||||||
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
|
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
|
||||||
model.default_cfg = default_cfgs['vit_small_patch16_224']
|
model.default_cfg = default_cfgs['vit_small_patch16_224']
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(
|
||||||
|
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user