mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Rename test_tiny* -> test*. Fix ByobNet BasicBlock attn location and add test_byobnet model.
This commit is contained in:
parent
66a0eb4673
commit
55101028bb
@ -261,8 +261,9 @@ class BasicBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
x = self.conv1_kxk(x)
|
||||
x = self.conv2_kxk(x)
|
||||
x = self.attn(x)
|
||||
x = self.conv2_kxk(x)
|
||||
x = self.attn_last(x)
|
||||
x = self.drop_path(x)
|
||||
if self.shortcut is not None:
|
||||
x = x + self.shortcut(shortcut)
|
||||
@ -439,7 +440,6 @@ class EdgeBlock(nn.Module):
|
||||
downsample, in_chs, out_chs,
|
||||
stride=stride, dilation=dilation, apply_act=False, layers=layers,
|
||||
)
|
||||
|
||||
self.conv1_kxk = layers.conv_norm_act(
|
||||
in_chs, mid_chs, kernel_size,
|
||||
stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
|
||||
@ -1835,16 +1835,19 @@ model_cfgs = dict(
|
||||
stem_chs=64,
|
||||
),
|
||||
|
||||
test_tiny_resnet=ByoModelCfg(
|
||||
test_byobnet=ByoModelCfg(
|
||||
blocks=(
|
||||
ByoBlockCfg(type='basic', d=1, c=24, s=1, gs=1, br=0.25),
|
||||
ByoBlockCfg(type='basic', d=1, c=32, s=2, gs=1, br=0.25),
|
||||
ByoBlockCfg(type='basic', d=1, c=64, s=2, gs=1, br=0.25),
|
||||
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=1, br=0.25),
|
||||
ByoBlockCfg(type='edge', d=1, c=32, s=2, gs=0, br=0.5),
|
||||
ByoBlockCfg(type='dark', d=1, c=64, s=2, gs=0, br=0.5),
|
||||
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=32, br=0.25),
|
||||
ByoBlockCfg(type='bottle', d=1, c=256, s=2, gs=64, br=0.25),
|
||||
),
|
||||
stem_chs=32,
|
||||
stem_pool='maxpool',
|
||||
stem_chs=24,
|
||||
downsample='avg',
|
||||
stem_pool='',
|
||||
act_layer='relu',
|
||||
attn_layer='se',
|
||||
attn_kwargs=dict(rd_ratio=0.25),
|
||||
),
|
||||
)
|
||||
|
||||
@ -2048,7 +2051,7 @@ default_cfgs = generate_default_cfgs({
|
||||
first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
|
||||
),
|
||||
|
||||
'test_tiny_byobnet.untrained': _cfgr(
|
||||
'test_byobnet.untrained': _cfgr(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 160, 160), crop_pct=0.875, pool_size=(5, 5),
|
||||
),
|
||||
@ -2357,7 +2360,7 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
|
||||
|
||||
|
||||
@register_model
|
||||
def test_tiny_byobnet(pretrained=False, **kwargs) -> ByobNet:
|
||||
def test_byobnet(pretrained=False, **kwargs) -> ByobNet:
|
||||
""" Minimal test ResNet (BYOB based) model.
|
||||
"""
|
||||
return _create_byobnet('test_tiny_byobnet', pretrained=pretrained, **kwargs)
|
||||
return _create_byobnet('test_byobnet', pretrained=pretrained, **kwargs)
|
||||
|
@ -1610,7 +1610,7 @@ default_cfgs = generate_default_cfgs({
|
||||
url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth',
|
||||
hf_hub_id='timm/'),
|
||||
|
||||
"test_tiny_efficientnet.untrained": _cfg(
|
||||
"test_efficientnet.untrained": _cfg(
|
||||
# hf_hub_id='timm/'
|
||||
input_size=(3, 160, 160), pool_size=(5, 5)),
|
||||
})
|
||||
@ -2540,8 +2540,8 @@ def tinynet_e(pretrained=False, **kwargs) -> EfficientNet:
|
||||
|
||||
|
||||
@register_model
|
||||
def test_tiny_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
|
||||
model = _gen_test_efficientnet('test_tiny_efficientnet', pretrained=pretrained, **kwargs)
|
||||
def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
|
||||
model = _gen_test_efficientnet('test_efficientnet', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -1937,7 +1937,7 @@ default_cfgs = {
|
||||
'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
|
||||
input_size=(3, 256, 256)),
|
||||
|
||||
'test_tiny_vit.untrained': _cfg(
|
||||
'test_vit.untrained': _cfg(
|
||||
input_size=(3, 160, 160), crop_pct=0.875),
|
||||
}
|
||||
|
||||
@ -3110,11 +3110,11 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
|
||||
|
||||
|
||||
@register_model
|
||||
def test_tiny_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT-TestTiny
|
||||
def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" ViT Test
|
||||
"""
|
||||
model_args = dict(patch_size=16, embed_dim=64, depth=4, num_heads=1, mlp_ratio=3)
|
||||
model = _create_vision_transformer('test_tiny_vit', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
model_args = dict(patch_size=16, embed_dim=64, depth=6, num_heads=2, mlp_ratio=3)
|
||||
model = _create_vision_transformer('test_vit', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user