add `_load_pretrained_parameters`
parent
48cdd3623d
commit
bdd8178c15
|
@ -67,8 +67,11 @@ def _load_pretrained_urls():
|
||||||
_checkpoints = _load_pretrained_urls()
|
_checkpoints = _load_pretrained_urls()
|
||||||
|
|
||||||
|
|
||||||
def _load_parameters(model, ):
|
def _load_pretrained_parameters(model, name):
|
||||||
pass
|
assert name in _checkpoints, 'Not provide {} pretrained model.'.format(name)
|
||||||
|
path = paddle.utils.download.get_weights_path_from_url(_checkpoints[name])
|
||||||
|
model.set_state_dict(paddle.load(path))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def AlexNet(pretrained=False, **kwargs):
|
def AlexNet(pretrained=False, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue