mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #1547 from Wauplin/1546-fix-hf-hub-integration
Create repo before cloning with Repository.clone_from
This commit is contained in:
commit
25ffac6880
@ -3,10 +3,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.hub import get_dir
|
from torch.hub import get_dir
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -15,7 +17,10 @@ except ImportError:
|
|||||||
from timm import __version__
|
from timm import __version__
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
|
from huggingface_hub import (create_repo, get_hf_file_metadata,
|
||||||
|
hf_hub_download, hf_hub_url,
|
||||||
|
repo_type_and_id_from_hf_id, upload_folder)
|
||||||
|
from huggingface_hub.utils import EntryNotFoundError
|
||||||
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
||||||
_has_hf_hub = True
|
_has_hf_hub = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -121,53 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
|
|||||||
|
|
||||||
def push_to_hf_hub(
|
def push_to_hf_hub(
|
||||||
model,
|
model,
|
||||||
local_dir,
|
repo_id: str,
|
||||||
repo_namespace_or_url=None,
|
commit_message: str ='Add model',
|
||||||
commit_message='Add model',
|
token: Optional[str] = None,
|
||||||
use_auth_token=True,
|
revision: Optional[str] = None,
|
||||||
git_email=None,
|
private: bool = False,
|
||||||
git_user=None,
|
create_pr: bool = False,
|
||||||
revision=None,
|
model_config: Optional[dict] = None,
|
||||||
model_config=None,
|
|
||||||
):
|
):
|
||||||
if repo_namespace_or_url:
|
# Create repo if doesn't exist yet
|
||||||
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
|
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||||
else:
|
|
||||||
if isinstance(use_auth_token, str):
|
|
||||||
token = use_auth_token
|
|
||||||
else:
|
|
||||||
token = HfFolder.get_token()
|
|
||||||
|
|
||||||
if token is None:
|
# Infer complete repo_id from repo_url
|
||||||
raise ValueError(
|
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||||
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
|
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||||
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
|
repo_id = f"{repo_owner}/{repo_name}"
|
||||||
"token as the `use_auth_token` argument."
|
|
||||||
)
|
|
||||||
|
|
||||||
repo_owner = HfApi().whoami(token)['name']
|
# Check if README file already exist in repo
|
||||||
repo_name = Path(local_dir).name
|
try:
|
||||||
|
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||||
|
has_readme = True
|
||||||
|
except EntryNotFoundError:
|
||||||
|
has_readme = False
|
||||||
|
|
||||||
repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
|
# Dump model and push to Hub
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
repo = Repository(
|
|
||||||
local_dir,
|
|
||||||
clone_from=repo_url,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
git_user=git_user,
|
|
||||||
git_email=git_email,
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare a default model card that includes the necessary tags to enable inference.
|
|
||||||
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
|
|
||||||
with repo.commit(commit_message):
|
|
||||||
# Save model weights and config.
|
# Save model weights and config.
|
||||||
save_for_hf(model, repo.local_dir, model_config=model_config)
|
save_for_hf(model, tmpdir, model_config=model_config)
|
||||||
|
|
||||||
# Save a model card if it doesn't exist.
|
# Add readme if does not exist
|
||||||
readme_path = Path(repo.local_dir) / 'README.md'
|
if not has_readme:
|
||||||
if not readme_path.exists():
|
readme_path = Path(tmpdir) / "README.md"
|
||||||
|
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
|
||||||
readme_path.write_text(readme_text)
|
readme_path.write_text(readme_text)
|
||||||
|
|
||||||
return repo.git_remote_url()
|
# Upload model and return
|
||||||
|
return upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
folder_path=tmpdir,
|
||||||
|
revision=revision,
|
||||||
|
create_pr=create_pr,
|
||||||
|
commit_message=commit_message,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user