mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add support to load safetensors weights
This commit is contained in:
parent
f35d6ea57b
commit
8470e29541
@ -17,21 +17,26 @@ import os
|
||||
import glob
|
||||
import hashlib
|
||||
from timm.models import load_state_dict
|
||||
import safetensors.torch
|
||||
|
||||
DEFAULT_OUTPUT = "./average.pth"
|
||||
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
|
||||
parser.add_argument('--input', default='', type=str, metavar='PATH',
|
||||
help='path to base input folder containing checkpoints')
|
||||
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
|
||||
help='checkpoint filter (path wildcard)')
|
||||
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
|
||||
help='output filename')
|
||||
parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
|
||||
help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
|
||||
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
|
||||
help='Force not using ema version of weights (if present)')
|
||||
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
|
||||
help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
|
||||
parser.add_argument('-n', type=int, default=10, metavar='N',
|
||||
help='Number of checkpoints to average')
|
||||
|
||||
parser.add_argument('--safetensors', action='store_true',
|
||||
help='Save weights using safetensors instead of the default torch way (pickle).')
|
||||
|
||||
def checkpoint_metric(checkpoint_path):
|
||||
if not checkpoint_path or not os.path.isfile(checkpoint_path):
|
||||
@ -55,6 +60,15 @@ def main():
|
||||
# by default sort by checkpoint metric (if present) and avg top n checkpoints
|
||||
args.sort = not args.no_sort
|
||||
|
||||
if args.safetensors and args.output == DEFAULT_OUTPUT:
|
||||
# Default path changes if using safetensors
|
||||
args.output = DEFAULT_SAFE_OUTPUT
|
||||
if args.safetensors and not args.output.endswith(".safetensors"):
|
||||
print(
|
||||
"Warning: saving weights as safetensors but output file extension is not "
|
||||
f"set to '.safetensors': {args.output}"
|
||||
)
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
exit(1)
|
||||
@ -107,6 +121,9 @@ def main():
|
||||
v = v.clamp(float32_info.min, float32_info.max)
|
||||
final_state_dict[k] = v.to(dtype=torch.float32)
|
||||
|
||||
if args.safetensors:
|
||||
safetensors.torch.save_file(final_state_dict, args.output)
|
||||
else:
|
||||
try:
|
||||
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
|
||||
except:
|
||||
|
@ -11,8 +11,8 @@ import torch
|
||||
import argparse
|
||||
import os
|
||||
import hashlib
|
||||
import safetensors.torch
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from timm.models import load_state_dict
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
||||
@ -24,6 +24,8 @@ parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
|
||||
help='use ema version of weights if present')
|
||||
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
|
||||
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
|
||||
parser.add_argument('--safetensors', action='store_true',
|
||||
help='Save weights using safetensors instead of the default torch way (pickle).')
|
||||
|
||||
_TEMP_NAME = './_checkpoint.pth'
|
||||
|
||||
@ -35,10 +37,10 @@ def main():
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
exit(1)
|
||||
|
||||
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
|
||||
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
|
||||
|
||||
|
||||
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
|
||||
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
|
||||
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
|
||||
if checkpoint and os.path.isfile(checkpoint):
|
||||
print("=> Loading checkpoint '{}'".format(checkpoint))
|
||||
@ -53,6 +55,9 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
|
||||
new_state_dict[name] = v
|
||||
print("=> Loaded state_dict from '{}'".format(checkpoint))
|
||||
|
||||
if safe_serialization:
|
||||
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
|
||||
else:
|
||||
try:
|
||||
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
|
||||
except:
|
||||
@ -67,7 +72,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
|
||||
else:
|
||||
checkpoint_root = ''
|
||||
checkpoint_base = os.path.splitext(checkpoint)[0]
|
||||
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
|
||||
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
|
||||
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
|
||||
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
|
||||
return final_filename
|
||||
|
@ -2,3 +2,4 @@ torch>=1.7
|
||||
torchvision
|
||||
pyyaml
|
||||
huggingface_hub
|
||||
safetensors>=0.2
|
@ -7,6 +7,7 @@ import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import safetensors.torch
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
@ -26,7 +27,12 @@ def clean_state_dict(state_dict):
|
||||
|
||||
def load_state_dict(checkpoint_path, use_ema=True):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
# Check if safetensors or not and load weights accordingly
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
||||
state_dict_key = ''
|
||||
if isinstance(checkpoint, dict):
|
||||
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
|
||||
|
@ -2,19 +2,25 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing import Iterable, Optional, Union
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
import safetensors.torch
|
||||
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
from typing_extensions import Literal
|
||||
|
||||
from timm import __version__
|
||||
from timm.models._pretrained import filter_pretrained_cfg
|
||||
|
||||
@ -35,6 +41,9 @@ _logger = logging.getLogger(__name__)
|
||||
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
|
||||
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
|
||||
|
||||
# Default name for a weights file hosted on the Huggingface Hub.
|
||||
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
|
||||
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
|
||||
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
|
||||
return pretrained_cfg, model_name
|
||||
|
||||
|
||||
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
|
||||
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = download_from_hf(model_id, filename)
|
||||
state_dict = torch.load(cached_file, map_location='cpu')
|
||||
return state_dict
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
|
||||
# Look for .safetensors alternatives and load from it if it exists
|
||||
for safe_filename in _get_safe_alternatives(filename):
|
||||
try:
|
||||
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
|
||||
_logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
|
||||
return safetensors.torch.load_file(cached_safe_file, device="cpu")
|
||||
except EntryNotFoundError:
|
||||
pass
|
||||
|
||||
# Otherwise, load using pytorch.load
|
||||
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
||||
_logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||
return torch.load(cached_file, map_location='cpu')
|
||||
|
||||
|
||||
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
|
||||
@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = N
|
||||
json.dump(hf_config, f, indent=2)
|
||||
|
||||
|
||||
def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
|
||||
def save_for_hf(
|
||||
model,
|
||||
save_directory: str,
|
||||
model_config: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False
|
||||
):
|
||||
assert has_hf_hub(True)
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
weights_path = save_directory / 'pytorch_model.bin'
|
||||
torch.save(model.state_dict(), weights_path)
|
||||
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
|
||||
tensors = model.state_dict()
|
||||
if safe_serialization is True or safe_serialization == "both":
|
||||
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
|
||||
if safe_serialization is False or safe_serialization == "both":
|
||||
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
|
||||
|
||||
config_path = save_directory / 'config.json'
|
||||
save_config_for_hf(model, config_path, model_config=model_config)
|
||||
@ -217,7 +247,15 @@ def push_to_hf_hub(
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
model_card: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
(...)
|
||||
safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
Can be set to `"both"` in order to push both safe and unsafe weights.
|
||||
"""
|
||||
# Create repo if it doesn't exist yet
|
||||
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||
|
||||
@ -236,7 +274,7 @@ def push_to_hf_hub(
|
||||
# Dump model and push to Hub
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Save model weights and config.
|
||||
save_for_hf(model, tmpdir, model_config=model_config)
|
||||
save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
|
||||
|
||||
# Add readme if it does not exist
|
||||
if not has_readme:
|
||||
@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
|
||||
for c in citations:
|
||||
readme_text += f"```bibtex\n{c}\n```\n"
|
||||
return readme_text
|
||||
|
||||
def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
||||
"""Returns potential safetensors alternatives for a given filename.
|
||||
|
||||
Use case:
|
||||
When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
|
||||
Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
|
||||
"""
|
||||
if filename == HF_WEIGHTS_NAME:
|
||||
yield HF_SAFE_WEIGHTS_NAME
|
||||
if filename.endswith(".bin"):
|
||||
yield filename[:-4] + ".safetensors"
|
Loading…
x
Reference in New Issue
Block a user