add `_load_pretrained_parameters`
parent
48cdd3623d
commit
bdd8178c15
|
@ -67,9 +67,12 @@ def _load_pretrained_urls():
|
|||
_checkpoints = _load_pretrained_urls()
|
||||
|
||||
|
||||
def _load_parameters(model, ):
|
||||
pass
|
||||
|
||||
def _load_pretrained_parameters(model, name):
|
||||
assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name)
|
||||
path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name])
|
||||
model.set_state_dict(paddle.load(path))
|
||||
return model
|
||||
|
||||
|
||||
def AlexNet(pretrained=False, **kwargs):
|
||||
'''AlexNet
|
||||
|
|
Loading…
Reference in New Issue