add seqeezenet series

This commit is contained in:
lyuwenyu 2021-03-25 15:10:17 +08:00
parent 931f3f1876
commit aa653c8620

View File

@ -10,6 +10,13 @@ from ppcls.modeling.architectures import squeezenet as _squeezenet
from ppcls.modeling.architectures import densenet as _densenet
from ppcls.modeling.architectures import inception_v3 as _inception_v3
from ppcls.modeling.architectures import inception_v4 as _inception_v4
from ppcls.modeling.architectures import googlenet as _googlenet
from ppcls.modeling.architectures import shufflenet_v2 as _shufflenet_v2
from ppcls.modeling.architectures import mobilenet_v1 as _mobilenet_v1
from ppcls.modeling.architectures import mobilenet_v2 as _mobilenet_v2
from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3
from ppcls.modeling.architectures import resnext as _resnext
# _checkpoints = {
# 'ResNet18': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams',
@ -188,3 +195,35 @@ def ResNet152(**kwargs):
model.set_state_dict(paddle.load(path))
return model
def SqueezeNet1_0(**kwargs):
'''SqueezeNet1_0
'''
pretrained = kwargs.pop('pretrained', False)
model = _squeezenet.SqueezeNet1_0(**kwargs)
if pretrained:
assert 'SqueezeNet1_0' in _checkpoints, 'Not provide `SqueezeNet1_0` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_0'])
model.set_state_dict(paddle.load(path))
return model
def SqueezeNet1_1(**kwargs):
'''SqueezeNet1_1
'''
pretrained = kwargs.pop('pretrained', False)
model = _squeezenet.SqueezeNet1_1(**kwargs)
if pretrained:
assert 'SqueezeNet1_1' in _checkpoints, 'Not provide `SqueezeNet1_1` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_1'])
model.set_state_dict(paddle.load(path))
return model