mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Update mobilenet_v1.py
This commit is contained in:
parent
5f3456767d
commit
86eb2bc693
@ -95,7 +95,7 @@ class DepthwiseSeparable(TheseusLayer):
|
||||
|
||||
|
||||
class MobileNet(TheseusLayer):
|
||||
def __init__(self, scale=1.0, class_dim=1000):
|
||||
def __init__(self, scale=1.0, class_num=1000):
|
||||
super(MobileNet, self).__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
@ -135,7 +135,7 @@ class MobileNet(TheseusLayer):
|
||||
|
||||
self.out = Linear(
|
||||
int(1024 * scale),
|
||||
class_dim,
|
||||
class_num,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()))
|
||||
|
||||
|
||||
@ -150,21 +150,57 @@ class MobileNet(TheseusLayer):
|
||||
|
||||
|
||||
def MobileNetV1_x0_25(**args):
|
||||
"""
|
||||
MobileNetV1_x0_25
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.25, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV1_x0_5(**args):
|
||||
"""
|
||||
MobileNetV1_x0_5
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.5, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV1_x0_75(**args):
|
||||
"""
|
||||
MobileNetV1_x0_75
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=0.75, **args)
|
||||
return model
|
||||
|
||||
|
||||
def MobileNetV1(**args):
|
||||
"""
|
||||
MobileNetV1
|
||||
Args:
|
||||
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
|
||||
kwargs:
|
||||
class_num: int=1000. Output dim of last fc layer.
|
||||
Returns:
|
||||
model: nn.Layer. Specific `MobileNetV1` model depends on args.
|
||||
"""
|
||||
model = MobileNet(scale=1.0, **args)
|
||||
return model
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user