add `_load_pretrained_parameters`

pull/701/head
lyuwenyu 2021-04-02 20:01:55 +08:00
parent 48cdd3623d
commit bdd8178c15
1 changed files with 6 additions and 3 deletions

View File

@ -67,9 +67,12 @@ def _load_pretrained_urls():
_checkpoints = _load_pretrained_urls()
def _load_parameters(model, ):
pass
def _load_pretrained_parameters(model, name):
assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name)
path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name])
model.set_state_dict(paddle.load(path))
return model
def AlexNet(pretrained=False, **kwargs):
'''AlexNet