mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix stem width for really small mobilenetv3 arch defs
This commit is contained in:
parent
d6cf6b3a3a
commit
a3dfd180aa
@ -110,9 +110,10 @@ class MobileNetV3(nn.Module):
|
||||
* LCNet - https://arxiv.org/abs/2109.15099
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
|
||||
pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
|
||||
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
|
||||
def __init__(
|
||||
self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
|
||||
head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
|
||||
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
|
||||
super(MobileNetV3, self).__init__()
|
||||
act_layer = act_layer or nn.ReLU
|
||||
norm_layer = norm_layer or nn.BatchNorm2d
|
||||
@ -122,7 +123,8 @@ class MobileNetV3(nn.Module):
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
# Stem
|
||||
stem_size = round_chs_fn(stem_size)
|
||||
if not fix_stem:
|
||||
stem_size = round_chs_fn(stem_size)
|
||||
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
@ -188,8 +190,8 @@ class MobileNetV3Features(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
|
||||
stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
|
||||
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
|
||||
stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
|
||||
se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
|
||||
super(MobileNetV3Features, self).__init__()
|
||||
act_layer = act_layer or nn.ReLU
|
||||
norm_layer = norm_layer or nn.BatchNorm2d
|
||||
@ -197,7 +199,8 @@ class MobileNetV3Features(nn.Module):
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
# Stem
|
||||
stem_size = round_chs_fn(stem_size)
|
||||
if not fix_stem:
|
||||
stem_size = round_chs_fn(stem_size)
|
||||
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
@ -381,6 +384,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
|
||||
block_args=decode_arch_def(arch_def),
|
||||
num_features=num_features,
|
||||
stem_size=16,
|
||||
fix_stem=channel_multiplier < 0.75,
|
||||
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
|
||||
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
|
||||
act_layer=act_layer,
|
||||
|
Loading…
x
Reference in New Issue
Block a user