Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03)
A few MobileNetV3 tweaks

* fix expansion ratio on early block
* change comment re stride mistake in paper
* fix rounding not being called properly for all multipliers != 1.0
parent 6523e4abe4
commit 17da1adaca
@@ -285,7 +285,7 @@ class _BlockBuilder:
     def _make_block(self, ba):
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
-        ba['out_chs'] = _round_channels(ba['out_chs'])
+        ba['out_chs'] = self._round_channels(ba['out_chs'])
         ba['bn_momentum'] = self.bn_momentum
         ba['bn_eps'] = self.bn_eps
         ba['folded_bn'] = self.folded_bn
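The fix above routes rounding through the builder's bound method so the configured channel multiplier is actually applied; the bare `_round_channels(...)` call rounded with default arguments, which is why multipliers != 1.0 were broken per the commit message. For reference, a minimal self-contained sketch of the usual MobileNet-family rounding rule (scale by the multiplier, then round to a multiple of 8 without dropping more than 10%); the name and signature here are illustrative, not necessarily this repo's exact helper:

def round_channels(channels, multiplier=1.0, divisor=8, min_depth=None):
    # Scale by the width multiplier, then round to the nearest multiple of
    # `divisor`, never dropping more than 10% below the scaled value.
    if not multiplier:
        return channels
    channels *= multiplier
    min_depth = min_depth or divisor
    rounded = max(min_depth, int(channels + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * channels:  # guard against rounding too far down
        rounded += divisor
    return rounded

For example, round_channels(24, multiplier=0.75) scales to 18, rounds down to 16, then bumps back up to 24 because 16 falls below 90% of 18.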
@@ -676,6 +676,7 @@ class GenMobileNet(nn.Module):
         # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
         x = self.global_pool(x)  # always need to pool here regardless of bool
         x = self.conv_head(x)
+        # no BN
         x = self.act_fn(x)
         if pool:
             # expect flattened output if pool is true, otherwise keep dim
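The added `# no BN` comment documents the efficient-head ordering: pooling runs before the final 1x1 conv, and that conv is followed by an activation only, with no batch norm. A minimal runnable sketch of the same ordering, assuming the attribute names from the diff (`global_pool`, `conv_head`, `act_fn`) and substituting ReLU for MobileNetV3's hard-swish:

import torch
import torch.nn as nn
import torch.nn.functional as F

class EfficientHead(nn.Module):
    def __init__(self, in_chs=160, head_chs=960):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)        # pool first, so the 1x1 conv sees a 1x1 map
        self.conv_head = nn.Conv2d(in_chs, head_chs, 1)   # final 1x1 conv, no BN after it
        self.act_fn = F.relu                              # hard-swish in MobileNetV3 proper

    def forward(self, x, pool=True):
        x = self.global_pool(x)
        x = self.conv_head(x)
        # no BN
        x = self.act_fn(x)
        if pool:
            x = x.view(x.size(0), -1)  # flatten when pooled output is expected
        return x

if __name__ == '__main__':
    print(EfficientHead()(torch.randn(2, 160, 7, 7)).shape)  # torch.Size([2, 960])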
@@ -884,7 +885,7 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
         # stage 0, 112x112 in
         ['ds_r1_k3_s1_e1_c16_are_noskip'],  # relu
         # stage 1, 112x112 in
-        ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e6_c24_are'],  # relu
+        ['ir_r1_k3_s2_e4_c24_are', 'ir_r1_k3_s1_e3_c24_are'],  # relu
         # stage 2, 56x56 in
         ['ir_r3_k5_s2_e3_c40_se0.25_are'],  # relu
         # stage 3, 28x28 in
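For readers decoding these strings: each entry encodes one block as underscore-separated fields: a block type (`ds` depthwise-separable, `ir` inverted residual, `cn` plain conv), then `r` repeats, `k` kernel size, `s` stride, `e` expansion ratio (the value fixed in this hunk), `c` output channels, `se` squeeze-excite ratio, an activation tag (`are` reading as ReLU here, hard-swish otherwise), and `noskip` to disable the residual. A rough illustrative parser under that assumed field layout; the repo's actual decoder may differ:

import re

def decode_block_str(block_str):
    parts = block_str.split('_')
    args = {'block_type': parts[0], 'noskip': 'noskip' in parts}
    keys = {'r': 'repeats', 'k': 'kernel_size', 's': 'stride',
            'e': 'exp_ratio', 'c': 'out_chs', 'se': 'se_ratio', 'a': 'act_fn'}
    for p in parts[1:]:
        m = re.match(r'([a-z]+)([\d.]+|re)', p)
        if m and m.group(1) in keys:
            key, val = keys[m.group(1)], m.group(2)
            args[key] = val if key == 'act_fn' else float(val)
    return args

# decode_block_str('ir_r1_k3_s1_e3_c24_are') ->
# {'block_type': 'ir', 'noskip': False, 'repeats': 1.0, 'kernel_size': 3.0,
#  'stride': 1.0, 'exp_ratio': 3.0, 'out_chs': 24.0, 'act_fn': 're'}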
@@ -893,9 +894,10 @@ def _gen_mobilenet_v3(depth_multiplier, num_classes=1000, **kwargs):
         # stage 4, 14x14in
         ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
         # stage 5, 14x14in
-        # FIXME the paper contains a weird block-stride pattern 1-2-1 that doesn't fit the usual 2-1-...
-        # What is correct?
-        ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
+        # FIXME paper has a mistaken block-stride pattern 1-2-1 that doesn't fit the usual 2-1-..., ignoring
+        # The paper numbers result in an exp factor of 4.2 in the middle of this block, but keeping at 6
+        # results in a param count closer to 5.4m
+        ['ir_r1_k5_s2_e6_c160_se0.25', 'ir_r1_k5_s1_e6_c160_se0.25', 'ir_r1_k5_s1_e6_c160_se0.25'],  # hard-swish
         # stage 6, 7x7 in
         ['cn_r1_k1_s1_c960'],  # hard-swish
     ]
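The 4.2 in the new comment presumably comes from the paper's MobileNetV3-Large table, where this 160-channel stage uses an exp size of 672 at one block and 960 at the others; keeping `e6` throughout trades exact table fidelity for a parameter count nearer the paper's reported ~5.4M. The arithmetic, under that table-value assumption:

out_chs = 160
print(672 / out_chs)  # 4.2 -> the odd mid-stage expansion implied by the paper's table
print(960 / out_chs)  # 6.0 -> the uniform e6 factor this config keeps instead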