Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Add support to load safetensors weights

Commit: 8470e29541
Parent: f35d6ea57b
@@ -17,21 +17,26 @@ import os
 import glob
 import hashlib
 from timm.models import load_state_dict
+import safetensors.torch
 
+DEFAULT_OUTPUT = "./average.pth"
+DEFAULT_SAFE_OUTPUT = "./average.safetensors"
 
 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
 parser.add_argument('--input', default='', type=str, metavar='PATH',
                     help='path to base input folder containing checkpoints')
 parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                     help='checkpoint filter (path wildcard)')
-parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
-                    help='output filename')
+parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
+                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
 parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='Force not using ema version of weights (if present)')
 parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                     help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
 parser.add_argument('-n', type=int, default=10, metavar='N',
                     help='Number of checkpoints to average')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch way (pickle).')
 
 def checkpoint_metric(checkpoint_path):
     if not checkpoint_path or not os.path.isfile(checkpoint_path):
@@ -55,6 +60,15 @@ def main():
     # by default sort by checkpoint metric (if present) and avg top n checkpoints
     args.sort = not args.no_sort
 
+    if args.safetensors and args.output == DEFAULT_OUTPUT:
+        # Default path changes if using safetensors
+        args.output = DEFAULT_SAFE_OUTPUT
+    if args.safetensors and not args.output.endswith(".safetensors"):
+        print(
+            "Warning: saving weights as safetensors but output file extension is not "
+            f"set to '.safetensors': {args.output}"
+        )
+
     if os.path.exists(args.output):
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)
@@ -107,10 +121,13 @@ def main():
             v = v.clamp(float32_info.min, float32_info.max)
         final_state_dict[k] = v.to(dtype=torch.float32)
 
-    try:
-        torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
-    except:
-        torch.save(final_state_dict, args.output)
+    if args.safetensors:
+        safetensors.torch.save_file(final_state_dict, args.output)
+    else:
+        try:
+            torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
+        except:
+            torch.save(final_state_dict, args.output)
 
     with open(args.output, 'rb') as f:
         sha_hash = hashlib.sha256(f.read()).hexdigest()
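For illustration, a minimal sketch of the new save path outside the script (not part of this commit; the tensors and paths below are made up):

import torch
import safetensors.torch

# Illustrative averaged weights; in the averaging script these come from real checkpoints.
final_state_dict = {"head.weight": torch.zeros(10, 512), "head.bias": torch.zeros(10)}

use_safetensors = True
output = "./average.safetensors" if use_safetensors else "./average.pth"

if use_safetensors:
    # safetensors stores a flat {name: tensor} mapping and never unpickles arbitrary objects.
    safetensors.torch.save_file(final_state_dict, output)
else:
    torch.save(final_state_dict, output)

Invoking the averaging script with --safetensors now writes ./average.safetensors unless --output overrides it, per the default-switching logic in the previous hunk.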
@@ -11,8 +11,8 @@ import torch
 import argparse
 import os
 import hashlib
+import safetensors.torch
 import shutil
-from collections import OrderedDict
 from timm.models import load_state_dict
 
 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
@@ -24,6 +24,8 @@ parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                     help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch way (pickle).')
 
 _TEMP_NAME = './_checkpoint.pth'
 
@@ -35,10 +37,10 @@ def main():
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)
 
-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
+    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
 
 
-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
+def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if checkpoint and os.path.isfile(checkpoint):
         print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -53,10 +55,13 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
            new_state_dict[name] = v
        print("=> Loaded state_dict from '{}'".format(checkpoint))
 
-        try:
-            torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
-        except:
-            torch.save(new_state_dict, _TEMP_NAME)
+        if safe_serialization:
+            safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
+        else:
+            try:
+                torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
+            except:
+                torch.save(new_state_dict, _TEMP_NAME)
 
         with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()
@@ -67,7 +72,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
         else:
             checkpoint_root = ''
         checkpoint_base = os.path.splitext(checkpoint)[0]
-        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
+        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
         shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
         return final_filename
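A hedged usage sketch of the updated cleaner (not part of this commit; the module name and checkpoint path are assumptions for illustration):

from clean_checkpoint import clean_checkpoint  # assumes the cleaner script is importable under this name

# Strip everything but the (EMA) state_dict from a training checkpoint and
# re-save it with safetensors; the returned name embeds a SHA-256 prefix.
final_name = clean_checkpoint('./output/train/checkpoint.pth', use_ema=True, safe_serialization=True)
print(final_name)  # e.g. './output/train/checkpoint-ab12cd34.safetensors'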
@@ -2,3 +2,4 @@ torch>=1.7
 torchvision
 pyyaml
 huggingface_hub
+safetensors>=0.2
@@ -7,6 +7,7 @@ import os
 from collections import OrderedDict
 
 import torch
+import safetensors.torch
 
 import timm.models._builder
 
@@ -26,7 +27,12 @@ def clean_state_dict(state_dict):
 
 def load_state_dict(checkpoint_path, use_ema=True):
     if checkpoint_path and os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        # Check if safetensors or not and load weights accordingly
+        if str(checkpoint_path).endswith(".safetensors"):
+            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
         state_dict_key = ''
         if isinstance(checkpoint, dict):
             if use_ema and checkpoint.get('state_dict_ema', None) is not None:
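A brief sketch of how the extension-based dispatch plays out for callers (not part of this commit; the path is illustrative):

from timm.models import load_state_dict

# A '.safetensors' path is routed to safetensors.torch.load_file; any other
# path still goes through torch.load with map_location='cpu'.
state_dict = load_state_dict('./average.safetensors', use_ema=True)
print(len(state_dict), 'tensors loaded')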
@@ -2,19 +2,25 @@ import hashlib
 import json
 import logging
 import os
+import sys
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Iterable, Optional, Union
 
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
+import safetensors.torch
 
 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir
 
+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
 from timm import __version__
 from timm.models._pretrained import filter_pretrained_cfg
 
@@ -35,6 +41,9 @@ _logger = logging.getLogger(__name__)
 __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
            'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
 
+# Default name for a weights file hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
+HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version
 
 def get_cache_dir(child_dir=''):
     """
@@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name
 
 
-def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
+def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, filename)
-    state_dict = torch.load(cached_file, map_location='cpu')
-    return state_dict
+    hf_model_id, hf_revision = hf_split(model_id)
+
+    # Look for .safetensors alternatives and load from it if it exists
+    for safe_filename in _get_safe_alternatives(filename):
+        try:
+            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+            _logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
+            return safetensors.torch.load_file(cached_safe_file, device="cpu")
+        except EntryNotFoundError:
+            pass
+
+    # Otherwise, load using pytorch.load
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    _logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    return torch.load(cached_file, map_location='cpu')
 
 
 def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
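A hedged usage sketch of the new lookup order (not part of this commit; the repo id is hypothetical and the module path is assumed to be timm.models._hub):

from timm.models._hub import load_state_dict_from_hf

# If the repo ships 'model.safetensors' (or a '<name>.safetensors' sibling of the
# requested '.bin' file), that file is downloaded and loaded without pickle;
# otherwise the loader falls back to 'pytorch_model.bin' via torch.load.
state_dict = load_state_dict_from_hf('someuser/some-timm-model')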
@@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
         json.dump(hf_config, f, indent=2)
 
 
-def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
+def save_for_hf(
+        model,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False
+):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)
 
-    weights_path = save_directory / 'pytorch_model.bin'
-    torch.save(model.state_dict(), weights_path)
+    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
+    tensors = model.state_dict()
+    if safe_serialization is True or safe_serialization == "both":
+        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
+    if safe_serialization is False or safe_serialization == "both":
+        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
 
     config_path = save_directory / 'config.json'
     save_config_for_hf(model, config_path, model_config=model_config)
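A minimal sketch of calling the new signature directly (not part of this commit; model name, output directory, and module path are illustrative assumptions):

import timm
from timm.models._hub import save_for_hf

model = timm.create_model('resnet18', pretrained=False)
# "both" writes model.safetensors and pytorch_model.bin, plus config.json, into the directory.
save_for_hf(model, './hf_upload', safe_serialization="both")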
@@ -217,7 +247,15 @@ def push_to_hf_hub(
     create_pr: bool = False,
     model_config: Optional[dict] = None,
     model_card: Optional[dict] = None,
+    safe_serialization: Union[bool, Literal["both"]] = False
 ):
+    """
+    Arguments:
+        (...)
+        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            Can be set to `"both"` in order to push both safe and unsafe weights.
+    """
     # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
 
@@ -236,7 +274,7 @@ def push_to_hf_hub(
     # Dump model and push to Hub
     with TemporaryDirectory() as tmpdir:
         # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config)
+        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
 
         # Add readme if it does not exist
         if not has_readme:
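End to end, pushing safe weights could look like the following sketch (not part of this commit; the repo id is hypothetical and authentication via huggingface_hub is assumed to be configured):

import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model('resnet18', pretrained=False)
# Uploads model.safetensors (instead of pytorch_model.bin) together with config.json
# and a generated README to the given Hub repo.
push_to_hf_hub(model, 'someuser/resnet18-demo', safe_serialization=True)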
@@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
     for c in citations:
         readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text
+
+
+def _get_safe_alternatives(filename: str) -> Iterable[str]:
+    """Returns potential safetensors alternatives for a given filename.
+
+    Use case:
+        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
+        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
+    """
+    if filename == HF_WEIGHTS_NAME:
+        yield HF_SAFE_WEIGHTS_NAME
+    if filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"