mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
add seqeezenet series
This commit is contained in:
parent
931f3f1876
commit
aa653c8620
39
hubconf.py
39
hubconf.py
@ -10,6 +10,13 @@ from ppcls.modeling.architectures import squeezenet as _squeezenet
|
||||
from ppcls.modeling.architectures import densenet as _densenet
|
||||
from ppcls.modeling.architectures import inception_v3 as _inception_v3
|
||||
from ppcls.modeling.architectures import inception_v4 as _inception_v4
|
||||
from ppcls.modeling.architectures import googlenet as _googlenet
|
||||
from ppcls.modeling.architectures import shufflenet_v2 as _shufflenet_v2
|
||||
from ppcls.modeling.architectures import mobilenet_v1 as _mobilenet_v1
|
||||
from ppcls.modeling.architectures import mobilenet_v2 as _mobilenet_v2
|
||||
from ppcls.modeling.architectures import mobilenet_v3 as _mobilenet_v3
|
||||
from ppcls.modeling.architectures import resnext as _resnext
|
||||
|
||||
|
||||
# _checkpoints = {
|
||||
# 'ResNet18': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams',
|
||||
@ -188,3 +195,35 @@ def ResNet152(**kwargs):
|
||||
model.set_state_dict(paddle.load(path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def SqueezeNet1_0(**kwargs):
|
||||
'''SqueezeNet1_0
|
||||
'''
|
||||
pretrained = kwargs.pop('pretrained', False)
|
||||
|
||||
model = _squeezenet.SqueezeNet1_0(**kwargs)
|
||||
if pretrained:
|
||||
assert 'SqueezeNet1_0' in _checkpoints, 'Not provide `SqueezeNet1_0` pretrained model.'
|
||||
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_0'])
|
||||
model.set_state_dict(paddle.load(path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def SqueezeNet1_1(**kwargs):
|
||||
'''SqueezeNet1_1
|
||||
'''
|
||||
pretrained = kwargs.pop('pretrained', False)
|
||||
|
||||
model = _squeezenet.SqueezeNet1_1(**kwargs)
|
||||
if pretrained:
|
||||
assert 'SqueezeNet1_1' in _checkpoints, 'Not provide `SqueezeNet1_1` pretrained model.'
|
||||
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_1'])
|
||||
model.set_state_dict(paddle.load(path))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user