Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00.
Add small vision transformer weights. 77.42 top-1.
parent ccfb5751ab
commit d4db9e7977
@@ -29,7 +29,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import load_pretrained
 from .layers import DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
 from .registry import register_model
@@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     # patch models
-    'vit_small_patch16_224': _cfg(),
+    'vit_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+    ),
     'vit_base_patch16_224': _cfg(),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
@@ -271,6 +273,9 @@ class VisionTransformer(nn.Module):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
     model.default_cfg = default_cfgs['vit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
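With those pieces in place, loading the new weights becomes a one-liner. A minimal usage sketch, assuming this version of timm is installed (the 77.42 figure is the commit's reported ImageNet top-1 accuracy):

import torch
from timm.models.vision_transformer import vit_small_patch16_224

# Build the small ViT and fetch the released checkpoint via default_cfg['url'].
model = vit_small_patch16_224(pretrained=True)
model.eval()

# Single forward pass on a dummy 224x224 image.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])

Assuming the function carries the usual @register_model decoration (the decorator itself is outside this hunk, but register_model is imported above), the equivalent registry-based call is timm.create_model('vit_small_patch16_224', pretrained=True).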