mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #263 from rwightman/fixes_oct2020
Fixes for upcoming PyPi release
This commit is contained in:
commit
af3299ba4a
@ -24,7 +24,7 @@ MAX_FWD_FEAT_SIZE = 448
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(120)
|
@pytest.mark.timeout(120)
|
||||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
|
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1]))
|
||||||
@pytest.mark.parametrize('batch_size', [1])
|
@pytest.mark.parametrize('batch_size', [1])
|
||||||
def test_model_forward(model_name, batch_size):
|
def test_model_forward(model_name, batch_size):
|
||||||
"""Run a single forward pass with each model"""
|
"""Run a single forward pass with each model"""
|
||||||
|
@ -277,11 +277,12 @@ def build_model_with_cfg(
|
|||||||
if pruned:
|
if pruned:
|
||||||
model = adapt_model_from_file(model, variant)
|
model = adapt_model_from_file(model, variant)
|
||||||
|
|
||||||
|
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
||||||
|
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(
|
load_pretrained(
|
||||||
model,
|
model,
|
||||||
num_classes=kwargs.get('num_classes', 0),
|
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
|
||||||
in_chans=kwargs.get('in_chans', 3),
|
|
||||||
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
||||||
|
|
||||||
if features:
|
if features:
|
||||||
|
@ -776,6 +776,7 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
|
|||||||
strict = True
|
strict = True
|
||||||
if model_kwargs.pop('features_only', False):
|
if model_kwargs.pop('features_only', False):
|
||||||
model_cls = HighResolutionNetFeatures
|
model_cls = HighResolutionNetFeatures
|
||||||
|
model_kwargs['num_classes'] = 0
|
||||||
strict = False
|
strict = False
|
||||||
|
|
||||||
return build_model_with_cfg(
|
return build_model_with_cfg(
|
||||||
|
@ -6,9 +6,14 @@ from .activations_jit import *
|
|||||||
from .activations_me import *
|
from .activations_me import *
|
||||||
from .config import is_exportable, is_scriptable, is_no_jit
|
from .config import is_exportable, is_scriptable, is_no_jit
|
||||||
|
|
||||||
|
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
|
||||||
|
# will use native version if present. Eventually, the custom Swish layers will be removed
|
||||||
|
# and only native 'silu' will be used.
|
||||||
|
_has_silu = 'silu' in dir(torch.nn.functional)
|
||||||
|
|
||||||
_ACT_FN_DEFAULT = dict(
|
_ACT_FN_DEFAULT = dict(
|
||||||
swish=swish,
|
silu=F.silu if _has_silu else swish,
|
||||||
|
swish=F.silu if _has_silu else swish,
|
||||||
mish=mish,
|
mish=mish,
|
||||||
relu=F.relu,
|
relu=F.relu,
|
||||||
relu6=F.relu6,
|
relu6=F.relu6,
|
||||||
@ -26,7 +31,8 @@ _ACT_FN_DEFAULT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_FN_JIT = dict(
|
_ACT_FN_JIT = dict(
|
||||||
swish=swish_jit,
|
silu=F.silu if _has_silu else swish_jit,
|
||||||
|
swish=F.silu if _has_silu else swish_jit,
|
||||||
mish=mish_jit,
|
mish=mish_jit,
|
||||||
hard_sigmoid=hard_sigmoid_jit,
|
hard_sigmoid=hard_sigmoid_jit,
|
||||||
hard_swish=hard_swish_jit,
|
hard_swish=hard_swish_jit,
|
||||||
@ -34,7 +40,8 @@ _ACT_FN_JIT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_FN_ME = dict(
|
_ACT_FN_ME = dict(
|
||||||
swish=swish_me,
|
silu=F.silu if _has_silu else swish_me,
|
||||||
|
swish=F.silu if _has_silu else swish_me,
|
||||||
mish=mish_me,
|
mish=mish_me,
|
||||||
hard_sigmoid=hard_sigmoid_me,
|
hard_sigmoid=hard_sigmoid_me,
|
||||||
hard_swish=hard_swish_me,
|
hard_swish=hard_swish_me,
|
||||||
@ -42,7 +49,8 @@ _ACT_FN_ME = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_LAYER_DEFAULT = dict(
|
_ACT_LAYER_DEFAULT = dict(
|
||||||
swish=Swish,
|
silu=nn.SiLU if _has_silu else Swish,
|
||||||
|
swish=nn.SiLU if _has_silu else Swish,
|
||||||
mish=Mish,
|
mish=Mish,
|
||||||
relu=nn.ReLU,
|
relu=nn.ReLU,
|
||||||
relu6=nn.ReLU6,
|
relu6=nn.ReLU6,
|
||||||
@ -60,7 +68,8 @@ _ACT_LAYER_DEFAULT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_LAYER_JIT = dict(
|
_ACT_LAYER_JIT = dict(
|
||||||
swish=SwishJit,
|
silu=nn.SiLU if _has_silu else SwishJit,
|
||||||
|
swish=nn.SiLU if _has_silu else SwishJit,
|
||||||
mish=MishJit,
|
mish=MishJit,
|
||||||
hard_sigmoid=HardSigmoidJit,
|
hard_sigmoid=HardSigmoidJit,
|
||||||
hard_swish=HardSwishJit,
|
hard_swish=HardSwishJit,
|
||||||
@ -68,7 +77,8 @@ _ACT_LAYER_JIT = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
_ACT_LAYER_ME = dict(
|
_ACT_LAYER_ME = dict(
|
||||||
swish=SwishMe,
|
silu=nn.SiLU if _has_silu else SwishMe,
|
||||||
|
swish=nn.SiLU if _has_silu else SwishMe,
|
||||||
mish=MishMe,
|
mish=MishMe,
|
||||||
hard_sigmoid=HardSigmoidMe,
|
hard_sigmoid=HardSigmoidMe,
|
||||||
hard_swish=HardSwishMe,
|
hard_swish=HardSwishMe,
|
||||||
|
@ -37,7 +37,7 @@ def _cfg(url='', **kwargs):
|
|||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': '', 'classifier': 'head',
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,7 +48,8 @@ default_cfgs = {
|
|||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
||||||
),
|
),
|
||||||
'vit_base_patch16_224': _cfg(
|
'vit_base_patch16_224': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
),
|
),
|
||||||
'vit_base_patch16_384': _cfg(
|
'vit_base_patch16_384': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
|
||||||
@ -56,7 +57,9 @@ default_cfgs = {
|
|||||||
'vit_base_patch32_384': _cfg(
|
'vit_base_patch32_384': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
|
||||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||||
'vit_large_patch16_224': _cfg(),
|
'vit_large_patch16_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||||
'vit_large_patch16_384': _cfg(
|
'vit_large_patch16_384': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
|
||||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||||
@ -206,7 +209,7 @@ class VisionTransformer(nn.Module):
|
|||||||
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
|
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.embed_dim = embed_dim
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||||
|
|
||||||
if hybrid_backbone is not None:
|
if hybrid_backbone is not None:
|
||||||
self.patch_embed = HybridEmbed(
|
self.patch_embed = HybridEmbed(
|
||||||
@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vit_base_patch16_224(pretrained=False, **kwargs):
|
def vit_base_patch16_224(pretrained=False, **kwargs):
|
||||||
if pretrained:
|
model = VisionTransformer(
|
||||||
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||||
kwargs.setdefault('qk_scale', 768 ** -0.5)
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||||
model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
|
|
||||||
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
||||||
if pretrained:
|
if pretrained:
|
||||||
load_pretrained(
|
load_pretrained(
|
||||||
@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def vit_large_patch16_224(pretrained=False, **kwargs):
|
def vit_large_patch16_224(pretrained=False, **kwargs):
|
||||||
model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
|
model = VisionTransformer(
|
||||||
|
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||||
model.default_cfg = default_cfgs['vit_large_patch16_224']
|
model.default_cfg = default_cfgs['vit_large_patch16_224']
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -1 +1 @@
|
|||||||
__version__ = '0.2.2'
|
__version__ = '0.3.0'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user