Merge pull request #1351 from nateraw/use-hf-hub-download
Use hf_hub_download instead of cached_downloadfix_tests
commit
4283c0c478
|
@ -14,11 +14,11 @@ except ImportError:
|
||||||
|
|
||||||
from timm import __version__
|
from timm import __version__
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
|
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
|
||||||
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
|
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
||||||
_has_hf_hub = True
|
_has_hf_hub = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
cached_download = None
|
hf_hub_download = None
|
||||||
_has_hf_hub = False
|
_has_hf_hub = False
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
@ -78,8 +78,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
||||||
|
|
||||||
def _download_from_hf(model_id: str, filename: str):
|
def _download_from_hf(model_id: str, filename: str):
|
||||||
hf_model_id, hf_revision = hf_split(model_id)
|
hf_model_id, hf_revision = hf_split(model_id)
|
||||||
url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
|
return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
|
||||||
return cached_download(url, cache_dir=get_cache_dir('hf'))
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config_from_hf(model_id: str):
|
def load_model_config_from_hf(model_id: str):
|
||||||
|
|
Loading…
Reference in New Issue