polish download pretrain model and add comments to the func

pull/19/head
WuHaobo 2020-04-13 08:00:19 +08:00
parent fef101ddce
commit a39ae69ce7
1 changed files with 14 additions and 15 deletions

View File

@ -17,10 +17,10 @@ from __future__ import division
from __future__ import print_function
import os
import shutil
import requests
import tqdm
import shutil
import tarfile
import tqdm
import zipfile
from ppcls.utils.check import check_architecture
@ -40,18 +40,6 @@ class UrlError(Exception):
super(UrlError, self).__init__(message)
class ModelNameError(Exception):
""" ModelNameError
"""
def __init__(self, message='', architecture=''):
similar_names = similar_architectures(architecture)
model_list = ', '.join(similar_names)
message += '\n{} is not exist. \nMaybe you want: [{}]'.format(
architecture, model_list)
super(ModelNameError, self).__init__(message)
class RetryError(Exception):
""" RetryError
"""
@ -172,7 +160,18 @@ def _decompress(fname):
def get(architecture, path, decompress=True):
check_architecture(architecture)
"""
Get the pretrained model.
Args:
architecture: the name of which architecture to get.
If the name is not exist, will raises UrlError with error code 404.
path: which dir to save the pretrained model.
decompress: decompress the download or not.
Raises:
RetryError or UrlError if download failed
"""
url = _get_url(architecture)
fname = _download(url, path)
if decompress: _decompress(fname)