mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add a few more test model defs in prep for weight upload
This commit is contained in:
parent
6ab2af610d
commit
a2f539f055
@ -20,7 +20,7 @@ from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get
|
|||||||
from ._efficientnet_blocks import *
|
from ._efficientnet_blocks import *
|
||||||
from ._manipulate import named_modules
|
from ._manipulate import named_modules
|
||||||
|
|
||||||
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
|
__all__ = ["EfficientNetBuilder", "BlockArgs", "decode_arch_def", "efficientnet_init_weights",
|
||||||
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
@ -44,7 +44,8 @@ import torch.nn.functional as F
|
|||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType
|
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
|
||||||
|
GroupNormAct, LayerNormAct2d, EvoNorm2dS0
|
||||||
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
|
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
|
||||||
from ._efficientnet_blocks import SqueezeExcite
|
from ._efficientnet_blocks import SqueezeExcite
|
||||||
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
|
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
|
||||||
@ -1808,6 +1809,14 @@ default_cfgs = generate_default_cfgs({
|
|||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
input_size=(3, 160, 160), pool_size=(5, 5)),
|
input_size=(3, 160, 160), pool_size=(5, 5)),
|
||||||
|
"test_efficientnet_ln.r160_in1k": _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
|
input_size=(3, 160, 160), pool_size=(5, 5)),
|
||||||
|
"test_efficientnet_evos.r160_in1k": _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||||
|
input_size=(3, 160, 160), pool_size=(5, 5)),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@ -2802,6 +2811,21 @@ def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
|
|||||||
'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
|
'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def test_efficientnet_ln(pretrained=False, **kwargs) -> EfficientNet:
|
||||||
|
model = _gen_test_efficientnet(
|
||||||
|
'test_efficientnet_ln', pretrained=pretrained, norm_layer=LayerNormAct2d, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def test_efficientnet_evos(pretrained=False, **kwargs) -> EfficientNet:
|
||||||
|
model = _gen_test_efficientnet(
|
||||||
|
'test_efficientnet_evos', pretrained=pretrained, norm_layer=partial(EvoNorm2dS0, group_size=8), **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
register_model_deprecations(__name__, {
|
register_model_deprecations(__name__, {
|
||||||
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
|
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
|
||||||
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
|
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
|
||||||
|
@ -2015,6 +2015,12 @@ default_cfgs = {
|
|||||||
'test_vit.r160_in1k': _cfg(
|
'test_vit.r160_in1k': _cfg(
|
||||||
hf_hub_id='timm/',
|
hf_hub_id='timm/',
|
||||||
input_size=(3, 160, 160), crop_pct=0.875),
|
input_size=(3, 160, 160), crop_pct=0.875),
|
||||||
|
'test_vit2.r160_in1k': _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
input_size=(3, 160, 160), crop_pct=0.875),
|
||||||
|
'test_vit3.r160_in1k': _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
input_size=(3, 160, 160), crop_pct=0.875),
|
||||||
}
|
}
|
||||||
|
|
||||||
_quick_gelu_cfgs = [
|
_quick_gelu_cfgs = [
|
||||||
@ -3216,6 +3222,26 @@ def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_vit2(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
|
""" ViT Test
|
||||||
|
"""
|
||||||
|
model_args = dict(
|
||||||
|
patch_size=16, embed_dim=64, depth=8, num_heads=2, mlp_ratio=3,
|
||||||
|
class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5)
|
||||||
|
model = _create_vision_transformer('test_vit2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||||
|
""" ViT Test
|
||||||
|
"""
|
||||||
|
model_args = dict(
|
||||||
|
patch_size=16, embed_dim=96, depth=10, num_heads=3, mlp_ratio=2,
|
||||||
|
class_token=False, reg_tokens=1, global_pool='map', init_values=1e-5)
|
||||||
|
model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
register_model_deprecations(__name__, {
|
register_model_deprecations(__name__, {
|
||||||
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
|
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
|
||||||
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
|
'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user