mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add a per-stage ECA config for MobileNetV3 (MV3). Going to shelve these experiments for a while; they are not proving very productive.
This commit is contained in:
parent
ade1ba5fe3
commit
8bd08d5f64
@ -85,7 +85,8 @@ def _decode_block_str(block_str):
|
||||
attn_layer = 'sev2'
|
||||
attn_kwargs = dict(se_ratio=float(options['se']))
|
||||
elif 'eca' in options:
|
||||
attn_layer = 'eca'
|
||||
attn_layer = 'ceca'
|
||||
attn_kwargs = dict(kernel_size=int(options['eca']))
|
||||
|
||||
num_repeat = int(options['r'])
|
||||
# each type of block has different valid arguments, fill accordingly
|
||||
@ -378,7 +379,7 @@ class EfficientNetBuilder:
|
||||
return stages
|
||||
|
||||
|
||||
def _init_weight_goog(m, n='', fix_group_fanout=False):
|
||||
def _init_weight_goog(m, n='', fix_group_fanout=True):
|
||||
""" Weight initialization as per Tensorflow official implementations.
|
||||
|
||||
Args:
|
||||
|
@ -389,23 +389,38 @@ def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **k
|
||||
else:
|
||||
num_features = 1280
|
||||
act_layer = HardSwish
|
||||
# arch_def = [
|
||||
# # stage 0, 112x112 in
|
||||
# ['ds_r1_k3_s1_e1_c16_nre'], # relu
|
||||
# # stage 1, 112x112 in
|
||||
# ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# # stage 2, 56x56 in
|
||||
# ['ir_r3_k5_s2_e3_c40_nre'], # relu
|
||||
# # stage 3, 28x28 in
|
||||
# ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
# # stage 4, 14x14in
|
||||
# ['ir_r2_k3_s1_e6_c112'], # hard-swish
|
||||
# # stage 5, 14x14in
|
||||
# ['ir_r3_k5_s2_e6_c160'], # hard-swish
|
||||
# # stage 6, 7x7 in
|
||||
# ['cn_r1_k1_s1_c960'], # hard-swish
|
||||
# ]
|
||||
arch_def = [
|
||||
# stage 0, 112x112 in
|
||||
['ds_r1_k3_s1_e1_c16_nre'], # relu
|
||||
# stage 1, 112x112 in
|
||||
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
|
||||
# stage 2, 56x56 in
|
||||
['ir_r3_k5_s2_e3_c40_nre'], # relu
|
||||
['ir_r3_k5_s2_e3_c40_eca3_nre'], # relu
|
||||
# stage 3, 28x28 in
|
||||
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
|
||||
['ir_r1_k3_s2_e6_c80_eca3', 'ir_r1_k3_s1_e2.5_c80_eca3', 'ir_r2_k3_s1_e2.3_c80_eca3'], # hard-swish
|
||||
# stage 4, 14x14 in
|
||||
['ir_r2_k3_s1_e6_c112'], # hard-swish
|
||||
['ir_r2_k3_s1_e6_c112_eca5'], # hard-swish
|
||||
# stage 5, 14x14 in
|
||||
['ir_r3_k5_s2_e6_c160'], # hard-swish
|
||||
['ir_r3_k5_s2_e6_c160_eca5'], # hard-swish
|
||||
# stage 6, 7x7 in
|
||||
['cn_r1_k1_s1_c960'], # hard-swish
|
||||
]
|
||||
|
||||
model_kwargs = dict(
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
@ -413,7 +428,7 @@ def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **k
|
||||
channel_multiplier=channel_multiplier,
|
||||
norm_kwargs=resolve_bn_args(kwargs),
|
||||
act_layer=act_layer,
|
||||
attn_layer='ceca',
|
||||
#attn_layer='ceca',
|
||||
attn_kwargs=dict(gate_fn=hard_sigmoid),
|
||||
**kwargs,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user