refactor push_to_hub helper
parent
ae0a0db7de
commit
9b114754db
|
@ -3,10 +3,12 @@ import logging
|
|||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
|
@ -15,7 +17,10 @@ except ImportError:
|
|||
from timm import __version__
|
||||
|
||||
try:
|
||||
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
|
||||
from huggingface_hub import (create_repo, get_hf_file_metadata,
|
||||
hf_hub_download, hf_hub_url,
|
||||
repo_type_and_id_from_hf_id, upload_folder)
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
||||
_has_hf_hub = True
|
||||
except ImportError:
|
||||
|
@ -121,56 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
|
|||
|
||||
def push_to_hf_hub(
|
||||
model,
|
||||
local_dir,
|
||||
repo_namespace_or_url=None,
|
||||
commit_message='Add model',
|
||||
use_auth_token=True,
|
||||
git_email=None,
|
||||
git_user=None,
|
||||
revision=None,
|
||||
model_config=None,
|
||||
repo_id: str,
|
||||
commit_message: str ='Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
):
|
||||
if isinstance(use_auth_token, str):
|
||||
token = use_auth_token
|
||||
else:
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
raise ValueError(
|
||||
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
|
||||
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
|
||||
"token as the `use_auth_token` argument."
|
||||
)
|
||||
|
||||
if repo_namespace_or_url:
|
||||
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
|
||||
else:
|
||||
repo_owner = HfApi().whoami(token)['name']
|
||||
repo_name = Path(local_dir).name
|
||||
|
||||
repo_id = f'{repo_owner}/{repo_name}'
|
||||
repo_url = f'https://huggingface.co/{repo_id}'
|
||||
|
||||
# Create repo if doesn't exist yet
|
||||
HfApi().create_repo(repo_id, token=use_auth_token, exist_ok=True)
|
||||
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||
|
||||
repo = Repository(
|
||||
local_dir,
|
||||
clone_from=repo_url,
|
||||
use_auth_token=use_auth_token,
|
||||
git_user=git_user,
|
||||
git_email=git_email,
|
||||
revision=revision,
|
||||
)
|
||||
# Infer complete repo_id from repo_url
|
||||
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||
repo_id = f"{repo_owner}/{repo_name}"
|
||||
|
||||
# Prepare a default model card that includes the necessary tags to enable inference.
|
||||
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
|
||||
with repo.commit(commit_message):
|
||||
# Check if README file already exist in repo
|
||||
try:
|
||||
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||
has_readme = True
|
||||
except EntryNotFoundError:
|
||||
has_readme = False
|
||||
|
||||
# Dump model and push to Hub
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Save model weights and config.
|
||||
save_for_hf(model, repo.local_dir, model_config=model_config)
|
||||
save_for_hf(model, tmpdir, model_config=model_config)
|
||||
|
||||
# Save a model card if it doesn't exist.
|
||||
readme_path = Path(repo.local_dir) / 'README.md'
|
||||
if not readme_path.exists():
|
||||
# Add readme if does not exist
|
||||
if not has_readme:
|
||||
readme_path = Path(tmpdir) / "README.md"
|
||||
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
|
||||
readme_path.write_text(readme_text)
|
||||
|
||||
return repo.git_remote_url()
|
||||
# Upload model and return
|
||||
return upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=tmpdir,
|
||||
revision=revision,
|
||||
create_pr=create_pr,
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue