Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit d5afe106dc: Merge remote-tracking branch 'origin/tiny_test_models' into small_things
timm/models/byobnet.py

@@ -271,8 +271,9 @@ class BasicBlock(nn.Module):
     def forward(self, x):
         shortcut = x
         x = self.conv1_kxk(x)
-        x = self.conv2_kxk(x)
         x = self.attn(x)
+        x = self.conv2_kxk(x)
+        x = self.attn_last(x)
         x = self.drop_path(x)
         if self.shortcut is not None:
             x = x + self.shortcut(shortcut)
@@ -449,7 +450,6 @@ class EdgeBlock(nn.Module):
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_kxk = layers.conv_norm_act(
            in_chs, mid_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
@@ -1931,7 +1931,6 @@ model_cfgs = dict(
        aa_layer='avg',
        head_type='attn_abs',
    ),

    resnet101_clip=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
@@ -1946,7 +1945,6 @@ model_cfgs = dict(
        aa_layer='avg',
        head_type='attn_abs',
    ),

    resnet50x4_clip=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=4, c=256, s=1, br=0.25),
@@ -1962,7 +1960,6 @@ model_cfgs = dict(
        aa_layer='avg',
        head_type='attn_abs',
    ),

    resnet50x16_clip=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=6, c=256, s=1, br=0.25),
@@ -1978,7 +1975,6 @@ model_cfgs = dict(
        aa_layer='avg',
        head_type='attn_abs',
    ),

    resnet50x64_clip=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
@@ -2010,6 +2006,21 @@ model_cfgs = dict(
         head_hidden_size=1024,
         head_type='mlp',
     ),
+
+    test_byobnet=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='edge', d=1, c=32, s=2, gs=0, br=0.5),
+            ByoBlockCfg(type='dark', d=1, c=64, s=2, gs=0, br=0.5),
+            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=1, c=256, s=2, gs=64, br=0.25),
+        ),
+        stem_chs=24,
+        downsample='avg',
+        stem_pool='',
+        act_layer='relu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+    ),
 )
 for k in ('resnet50_clip', 'resnet101_clip', 'resnet50x4_clip', 'resnet50x16_clip', 'resnet50x64_clip'):
     model_cfgs[k + '_gap'] = replace(model_cfgs[k], head_type='classifier')
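Note on the shorthand in the new test_byobnet entry: as I read the ByoBlockCfg dataclass in byobnet.py, d is the number of blocks in the stage, c the output channels, s the stride of the stage's first block, gs the group size for grouped convolutions (0/None meaning a plain conv), and br the bottleneck ratio that sets the mid channels. A spelled-out sketch of one stage from the config above (the comments are annotation, not diff content):

    ByoBlockCfg(
        type='basic',  # block class for this stage ('basic', 'bottle', 'dark', 'edge', ...)
        d=1,           # depth: one block in the stage
        c=128,         # stage output channels
        s=2,           # stride of the first block (stage downsamples by 2)
        gs=32,         # group size for the kxk convs; 0/None -> standard convolution
        br=0.25,       # bottleneck ratio: mid channels roughly c * br
    )

The loop at the end uses dataclasses.replace to clone each CLIP ResNet config with head_type='classifier', registering it under a '_gap' suffix.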
@@ -2340,6 +2351,11 @@ default_cfgs = generate_default_cfgs({
     'resnet50_mlp.untrained': _cfgr(
         input_size=(3, 256, 256), pool_size=(8, 8),
     ),
+
+    'test_byobnet.untrained': _cfgr(
+        # hf_hub_id='timm/',
+        input_size=(3, 160, 160), crop_pct=0.875, pool_size=(5, 5),
+    ),
 })


@@ -2719,3 +2735,10 @@ def resnet50_mlp(pretrained=False, **kwargs) -> ByobNet:
     """
     """
     return _create_byobnet('resnet50_mlp', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def test_byobnet(pretrained=False, **kwargs) -> ByobNet:
+    """ Minimal test ResNet (BYOB based) model.
+    """
+    return _create_byobnet('test_byobnet', pretrained=pretrained, **kwargs)
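Taken together, the byobnet.py hunks register a 'test_byobnet' architecture with an .untrained default config. A minimal usage sketch, assuming a timm build that includes this commit (the 160x160 input size comes from the cfg above; nothing beyond the default 1000-class classifier is asserted):

    import torch
    import timm
    from timm.data import resolve_model_data_config

    # No pretrained weights exist for the test models, so pretrained stays False.
    model = timm.create_model('test_byobnet', pretrained=False).eval()

    # Pull input size / normalization from the model's default cfg.
    data_cfg = resolve_model_data_config(model)
    x = torch.randn(1, *data_cfg['input_size'])  # expected (1, 3, 160, 160)

    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # torch.Size([1, 1000]) with the default classifier head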
timm/models/efficientnet.py

@@ -1183,6 +1183,31 @@ def _gen_mobilenet_edgetpu(variant, channel_multiplier=1.0, depth_multiplier=1.0
     return model


+def _gen_test_efficientnet(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """ Minimal test EfficientNet generator.
+    """
+    arch_def = [
+        ['cn_r1_k3_s1_e1_c16_skip'],
+        ['er_r1_k3_s2_e4_c24'],
+        ['er_r1_k3_s2_e4_c32'],
+        ['ir_r1_k3_s2_e4_c48_se0.25'],
+        ['ir_r1_k3_s2_e4_c64_se0.25'],
+    ]
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def, depth_multiplier),
+        num_features=round_chs_fn(256),
+        stem_size=24,
+        round_chs_fn=round_chs_fn,
+        norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        act_layer=resolve_act_layer(kwargs, 'silu'),
+        **kwargs,
+    )
+    model = _create_effnet(variant, pretrained, **model_kwargs)
+    return model
+
+
 def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
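The arch_def block strings above use timm's compact block notation, decoded by decode_arch_def. A rough annotation of one entry, based on my reading of that notation (the comments are explanatory, not part of the diff):

    block = 'ir_r1_k3_s2_e4_c48_se0.25'
    # ir      -> inverted residual (MBConv) block; 'er' = edge/fused-MBConv style, 'cn' = plain conv+norm+act
    # r1      -> repeat the block once in this stage
    # k3      -> 3x3 (depthwise) kernel
    # s2      -> stride 2 for the first block of the stage
    # e4      -> expansion ratio 4
    # c48     -> 48 output channels
    # se0.25  -> squeeze-and-excitation with a 0.25 reduction ratio
    # A trailing '_skip' (as on the stem 'cn' block) requests a residual connection where shapes allow.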
@@ -1731,6 +1756,9 @@ default_cfgs = generate_default_cfgs({
         #hf_hub_id='timm/',
         input_size=(3, 224, 224), crop_pct=0.9),

+    "test_efficientnet.untrained": _cfg(
+        # hf_hub_id='timm/'
+        input_size=(3, 160, 160), pool_size=(5, 5)),
 })


@@ -2713,6 +2741,12 @@ def mobilenet_edgetpu_v2_l(pretrained=False, **kwargs) -> EfficientNet:
     return model


+@register_model
+def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
+    model = _gen_test_efficientnet('test_efficientnet', pretrained=pretrained, **kwargs)
+    return model
+
+
 register_model_deprecations(__name__, {
     'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
     'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
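As with the byobnet changes, the new EfficientNet variant goes through the normal factory. A short sketch, assuming this commit is installed; the 160x160 size follows the "test_efficientnet.untrained" cfg and the parameter count is intentionally left unasserted:

    import torch
    import timm

    # Override num_classes to show that the usual create_model kwargs still apply.
    model = timm.create_model('test_efficientnet', pretrained=False, num_classes=10).eval()
    print(f'params: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M')

    x = torch.randn(2, 3, 160, 160)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # torch.Size([2, 10])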
timm/models/vision_transformer.py

@@ -1957,13 +1957,16 @@ default_cfgs = {
         hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_base_patch16_reg4_gap_256': _cfg(
+    'vit_base_patch16_reg4_gap_256.untrained': _cfg(
         input_size=(3, 256, 256)),

-    'vit_so150m_patch16_reg4_gap_256': _cfg(
+    'vit_so150m_patch16_reg4_gap_256.untrained': _cfg(
         input_size=(3, 256, 256)),
-    'vit_so150m_patch16_reg4_map_256': _cfg(
+    'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
+
+    'test_vit.untrained': _cfg(
+        input_size=(3, 160, 160), crop_pct=0.875),
 }

 _quick_gelu_cfgs = [
@@ -3134,6 +3137,15 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
     return model


+@register_model
+def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT Test
+    """
+    model_args = dict(patch_size=16, embed_dim=64, depth=6, num_heads=2, mlp_ratio=3)
+    model = _create_vision_transformer('test_vit', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
     'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
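Finally, a sketch of the new ViT test model, assuming this commit is installed. The vit _cfg helper marks configs as fixed_input_size, so create_model should build the patch embedding for the 160x160 size given in 'test_vit.untrained'; if that assumption does not hold for a given timm version, img_size=160 can be passed explicitly:

    import torch
    import timm

    model = timm.create_model('test_vit', pretrained=False).eval()  # embed_dim=64, depth=6, heads=2

    x = torch.randn(1, 3, 160, 160)
    with torch.no_grad():
        feats = model.forward_features(x)  # (batch, tokens, embed_dim) token features
        logits = model(x)                  # (batch, num_classes) after pooling + head
    print(feats.shape, logits.shape)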