mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Placeholder for new mnv3 model
This commit is contained in:
parent
ed7aaf8d6d
commit
76b0e9931a
@ -453,7 +453,8 @@ def _gen_mobilenet_v3_rw(
|
|||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v3(
|
def _gen_mobilenet_v3(
|
||||||
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
|
variant: str, channel_multiplier: float = 1.0, depth_multiplier: float = 1.0,
|
||||||
|
group_size=None, pretrained: bool = False, **kwargs
|
||||||
) -> MobileNetV3:
|
) -> MobileNetV3:
|
||||||
"""Creates a MobileNet-V3 model.
|
"""Creates a MobileNet-V3 model.
|
||||||
|
|
||||||
@ -537,7 +538,7 @@ def _gen_mobilenet_v3(
|
|||||||
]
|
]
|
||||||
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
|
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=decode_arch_def(arch_def, group_size=group_size),
|
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, group_size=group_size),
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
stem_size=16,
|
stem_size=16,
|
||||||
fix_stem=channel_multiplier < 0.75,
|
fix_stem=channel_multiplier < 0.75,
|
||||||
@ -927,6 +928,9 @@ default_cfgs = generate_default_cfgs({
|
|||||||
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
|
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
|
||||||
paper_ids='arXiv:2104.10972v4',
|
paper_ids='arXiv:2104.10972v4',
|
||||||
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
|
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
|
||||||
|
'mobilenetv3_large_150d.untrained': _cfg(
|
||||||
|
#hf_hub_id='timm/',
|
||||||
|
),
|
||||||
|
|
||||||
'mobilenetv3_small_050.lamb_in1k': _cfg(
|
'mobilenetv3_small_050.lamb_in1k': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
|
||||||
@ -1099,6 +1103,11 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
|||||||
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def mobilenetv3_large_150d(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||||
|
""" MobileNet V3 """
|
||||||
|
model = _gen_mobilenet_v3('mobilenetv3_large_150d', 1.5, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user