mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
🎨 cleanup and add a couple comments
This commit is contained in:
parent
2b6ade24b3
commit
e65a2cba3d
@ -7,14 +7,12 @@ from typing import Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse, load_state_dict_from_url
|
from torch.hub import HASH_REGEX, download_url_to_file, urlparse, load_state_dict_from_url
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.hub import get_dir
|
from torch.hub import get_dir
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from torch.hub import _get_torch_home as get_dir
|
from torch.hub import _get_torch_home as get_dir
|
||||||
|
|
||||||
from timm import __version__
|
from timm import __version__
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
|
from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
|
||||||
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
|
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
|
||||||
@ -158,14 +156,15 @@ def push_to_hf_hub(
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prepare a default model card that includes the necessary tags to enable inference.
|
||||||
|
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
|
||||||
with repo.commit(commit_message):
|
with repo.commit(commit_message):
|
||||||
# Save model weights and config
|
# Save model weights and config.
|
||||||
save_pretrained_for_hf(model, repo.local_dir, **config_kwargs)
|
save_pretrained_for_hf(model, repo.local_dir, **config_kwargs)
|
||||||
|
|
||||||
# Save a model card if it doesn't exist, enabling inference.
|
# Save a model card if it doesn't exist.
|
||||||
readme_path = Path(repo.local_dir) / 'README.md'
|
readme_path = Path(repo.local_dir) / 'README.md'
|
||||||
readme_txt = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
|
|
||||||
if not readme_path.exists():
|
if not readme_path.exists():
|
||||||
readme_path.write_text(readme_txt)
|
readme_path.write_text(readme_text)
|
||||||
|
|
||||||
return repo.git_remote_url()
|
return repo.git_remote_url()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user