Update mobilenet_v1.py

This commit is contained in:
Bin Lu 2021-05-28 10:36:48 +08:00 committed by GitHub
parent 5f3456767d
commit 86eb2bc693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -95,7 +95,7 @@ class DepthwiseSeparable(TheseusLayer):
class MobileNet(TheseusLayer):
def __init__(self, scale=1.0, class_dim=1000):
def __init__(self, scale=1.0, class_num=1000):
super(MobileNet, self).__init__()
self.scale = scale
self.block_list = []
@ -135,7 +135,7 @@ class MobileNet(TheseusLayer):
self.out = Linear(
int(1024 * scale),
class_dim,
class_num,
weight_attr=ParamAttr(initializer=KaimingNormal()))
@ -150,21 +150,57 @@ class MobileNet(TheseusLayer):
def MobileNetV1_x0_25(**args):
"""
MobileNetV1_x0_25
Args:
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
kwargs:
class_num: int=1000. Output dim of last fc layer.
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
"""
model = MobileNet(scale=0.25, **args)
return model
def MobileNetV1_x0_5(**args):
"""
MobileNetV1_x0_5
Args:
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
kwargs:
class_num: int=1000. Output dim of last fc layer.
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
"""
model = MobileNet(scale=0.5, **args)
return model
def MobileNetV1_x0_75(**args):
"""
MobileNetV1_x0_75
Args:
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
kwargs:
class_num: int=1000. Output dim of last fc layer.
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
"""
model = MobileNet(scale=0.75, **args)
return model
def MobileNetV1(**args):
"""
MobileNetV1
Args:
pretrained: bool=False. If `True` load pretrained parameters, `False` otherwise.
kwargs:
class_num: int=1000. Output dim of last fc layer.
Returns:
model: nn.Layer. Specific `MobileNetV1` model depends on args.
"""
model = MobileNet(scale=1.0, **args)
return model