Merge pull request #1547 from Wauplin/1546-fix-hf-hub-integration

Create repo before cloning with Repository.clone_from
This commit is contained in:
Ross Wightman 2022-11-16 09:19:56 -08:00 committed by GitHub
commit 25ffac6880
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,10 +3,12 @@ import logging
import os import os
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Union from tempfile import TemporaryDirectory
from typing import Optional, Union
import torch import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try: try:
from torch.hub import get_dir from torch.hub import get_dir
except ImportError: except ImportError:
@ -15,7 +17,10 @@ except ImportError:
from timm import __version__ from timm import __version__
try: try:
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url from huggingface_hub import (create_repo, get_hf_file_metadata,
hf_hub_download, hf_hub_url,
repo_type_and_id_from_hf_id, upload_folder)
from huggingface_hub.utils import EntryNotFoundError
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
_has_hf_hub = True _has_hf_hub = True
except ImportError: except ImportError:
@ -121,53 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
def push_to_hf_hub( def push_to_hf_hub(
model, model,
local_dir, repo_id: str,
repo_namespace_or_url=None, commit_message: str ='Add model',
commit_message='Add model', token: Optional[str] = None,
use_auth_token=True, revision: Optional[str] = None,
git_email=None, private: bool = False,
git_user=None, create_pr: bool = False,
revision=None, model_config: Optional[dict] = None,
model_config=None,
): ):
if repo_namespace_or_url: # Create repo if doesn't exist yet
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:] repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
else:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None: # Infer complete repo_id from repo_url
raise ValueError( # Can be different from the input `repo_id` if repo_owner was implicit
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and " _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " repo_id = f"{repo_owner}/{repo_name}"
"token as the `use_auth_token` argument."
)
repo_owner = HfApi().whoami(token)['name'] # Check if README file already exist in repo
repo_name = Path(local_dir).name try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}' # Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
repo = Repository(
local_dir,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
revision=revision,
)
# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
with repo.commit(commit_message):
# Save model weights and config. # Save model weights and config.
save_for_hf(model, repo.local_dir, model_config=model_config) save_for_hf(model, tmpdir, model_config=model_config)
# Save a model card if it doesn't exist. # Add readme if does not exist
readme_path = Path(repo.local_dir) / 'README.md' if not has_readme:
if not readme_path.exists(): readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
readme_path.write_text(readme_text) readme_path.write_text(readme_text)
return repo.git_remote_url() # Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)