diff --git a/avg_checkpoints.py b/avg_checkpoints.py
index 83af5bbd..bdfa2265 100755
--- a/avg_checkpoints.py
+++ b/avg_checkpoints.py
@@ -17,21 +17,26 @@ import os
 import glob
 import hashlib
 from timm.models import load_state_dict
+import safetensors.torch
+
+DEFAULT_OUTPUT = "./average.pth"
+DEFAULT_SAFE_OUTPUT = "./average.safetensors"
 
 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
 parser.add_argument('--input', default='', type=str, metavar='PATH',
                     help='path to base input folder containing checkpoints')
 parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                     help='checkpoint filter (path wildcard)')
-parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
-                    help='output filename')
+parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
+                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
 parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='Force not using ema version of weights (if present)')
 parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                     help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
 parser.add_argument('-n', type=int, default=10, metavar='N',
                     help='Number of checkpoints to average')
-
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch pickle format.')
 
 def checkpoint_metric(checkpoint_path):
     if not checkpoint_path or not os.path.isfile(checkpoint_path):
@@ -55,6 +60,15 @@ def main():
     # by default sort by checkpoint metric (if present) and avg top n checkpoints
     args.sort = not args.no_sort
 
+    if args.safetensors and args.output == DEFAULT_OUTPUT:
+        # Default output path changes when saving as safetensors
+        args.output = DEFAULT_SAFE_OUTPUT
+    if args.safetensors and not args.output.endswith(".safetensors"):
+        print(
+            "Warning: saving weights as safetensors but output file extension is not "
+            f"set to '.safetensors': {args.output}"
+        )
+
     if os.path.exists(args.output):
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)
@@ -107,10 +121,13 @@ def main():
             v = v.clamp(float32_info.min, float32_info.max)
         final_state_dict[k] = v.to(dtype=torch.float32)
 
-    try:
-        torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
-    except:
-        torch.save(final_state_dict, args.output)
+    if args.safetensors:
+        safetensors.torch.save_file(final_state_dict, args.output)
+    else:
+        try:
+            torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
+        except:
+            torch.save(final_state_dict, args.output)
 
     with open(args.output, 'rb') as f:
         sha_hash = hashlib.sha256(f.read()).hexdigest()
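Note on the save path above: `safetensors.torch.save_file()` only accepts a flat `Dict[str, torch.Tensor]`, which the averaged `final_state_dict` already is; wrapped checkpoint dicts with optimizer state or metadata cannot be serialized this way. A minimal round-trip sketch, with an illustrative file name that is not part of the diff:

```python
import torch
import safetensors.torch

# Flat mapping of parameter names to tensors -- the only structure
# safetensors can store; non-tensor values would raise an error here.
state_dict = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}
safetensors.torch.save_file(state_dict, "average.safetensors")

# load_file() returns the same flat dict, here materialized on CPU.
loaded = safetensors.torch.load_file("average.safetensors", device="cpu")
assert torch.equal(loaded["weight"], state_dict["weight"])
```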
diff --git a/clean_checkpoint.py b/clean_checkpoint.py
index 17c270db..d18951bc 100755
--- a/clean_checkpoint.py
+++ b/clean_checkpoint.py
@@ -11,8 +11,8 @@ import torch
 import argparse
 import os
 import hashlib
+import safetensors.torch
 import shutil
-from collections import OrderedDict
 from timm.models import load_state_dict
 
 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
@@ -24,6 +24,8 @@ parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                     help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch pickle format.')
 
 _TEMP_NAME = './_checkpoint.pth'
 
@@ -35,10 +37,10 @@ def main():
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)
 
-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
+    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
 
 
-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
+def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool = False):
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if checkpoint and os.path.isfile(checkpoint):
         print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -53,10 +55,13 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(checkpoint))
 
-        try:
-            torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
-        except:
-            torch.save(new_state_dict, _TEMP_NAME)
+        if safe_serialization:
+            safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
+        else:
+            try:
+                torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
+            except:
+                torch.save(new_state_dict, _TEMP_NAME)
 
         with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()
@@ -67,7 +72,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
         else:
             checkpoint_root = ''
             checkpoint_base = os.path.splitext(checkpoint)[0]
-        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
+        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
         shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
         return final_filename
diff --git a/requirements.txt b/requirements.txt
index 5846bb36..b750d03d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ torch>=1.7
 torchvision
 pyyaml
 huggingface_hub
+safetensors>=0.2
\ No newline at end of file
diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py
index 995292aa..adae77eb 100644
--- a/timm/models/_helpers.py
+++ b/timm/models/_helpers.py
@@ -7,6 +7,7 @@ import os
 from collections import OrderedDict
 
 import torch
+import safetensors.torch
 
 import timm.models._builder
 
@@ -26,7 +27,12 @@ def clean_state_dict(state_dict):
 
 def load_state_dict(checkpoint_path, use_ema=True):
     if checkpoint_path and os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        # Check whether the checkpoint is safetensors and load weights accordingly
+        if str(checkpoint_path).endswith(".safetensors"):
+            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
         state_dict_key = ''
         if isinstance(checkpoint, dict):
             if use_ema and checkpoint.get('state_dict_ema', None) is not None:
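For context, the updated `load_state_dict` dispatches purely on the file extension, so existing call sites keep working unchanged. An illustrative usage sketch (both paths are hypothetical):

```python
from timm.models import load_state_dict

# A ".safetensors" suffix routes to safetensors.torch.load_file(..., device='cpu').
state_dict = load_state_dict('average.safetensors')

# Any other extension still goes through torch.load(..., map_location='cpu'),
# including EMA weight selection when present in the checkpoint dict.
state_dict = load_state_dict('model_best.pth.tar', use_ema=True)
```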
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index e247382f..f69b00f2 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -2,19 +2,25 @@ import hashlib
 import json
 import logging
 import os
+import sys
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
-
+from typing import Iterable, Optional, Union
 
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
+import safetensors.torch
 
 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir
 
+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
 from timm import __version__
 from timm.models._pretrained import filter_pretrained_cfg
@@ -35,6 +41,9 @@ _logger = logging.getLogger(__name__)
 
 __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
            'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
+# Default names for weights files hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
+HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version
 
 def get_cache_dir(child_dir=''):
     """
@@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name
 
 
-def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
+def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, filename)
-    state_dict = torch.load(cached_file, map_location='cpu')
-    return state_dict
+    hf_model_id, hf_revision = hf_split(model_id)
+
+    # Look for .safetensors alternatives and load the first one that exists
+    for safe_filename in _get_safe_alternatives(filename):
+        try:
+            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+            _logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
+            return safetensors.torch.load_file(cached_safe_file, device="cpu")
+        except EntryNotFoundError:
+            pass
+
+    # Otherwise, load using torch.load
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    _logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    return torch.load(cached_file, map_location='cpu')
 
 
 def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
@@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
         json.dump(hf_config, f, indent=2)
 
 
-def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
+def save_for_hf(
+        model,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False
+):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)
 
-    weights_path = save_directory / 'pytorch_model.bin'
-    torch.save(model.state_dict(), weights_path)
+    # Save model weights, either safely (safetensors), via the legacy pytorch pickle approach, or both.
+    tensors = model.state_dict()
+    if safe_serialization is True or safe_serialization == "both":
+        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
+    if safe_serialization is False or safe_serialization == "both":
+        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
 
     config_path = save_directory / 'config.json'
     save_config_for_hf(model, config_path, model_config=model_config)
@@ -217,7 +247,15 @@ def push_to_hf_hub(
     create_pr: bool = False,
     model_config: Optional[dict] = None,
     model_card: Optional[dict] = None,
+    safe_serialization: Union[bool, Literal["both"]] = False
 ):
+    """
+    Arguments:
+        (...)
+        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (which uses `pickle`).
+            Can be set to `"both"` in order to push both safe and unsafe weights.
+    """
     # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
 
@@ -236,7 +274,7 @@ def push_to_hf_hub(
     # Dump model and push to Hub
     with TemporaryDirectory() as tmpdir:
         # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config)
+        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
 
         # Add readme if it does not exist
         if not has_readme:
@@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
         for c in citations:
             readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text
+
+def _get_safe_alternatives(filename: str) -> Iterable[str]:
+    """Returns potential safetensors alternatives for a given filename.
+
+    Use case:
+        When downloading a model from the Huggingface Hub, we first check whether a .safetensors file exists and, if so, use it.
+        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
+    """
+    if filename == HF_WEIGHTS_NAME:
+        yield HF_SAFE_WEIGHTS_NAME
+    if filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"
\ No newline at end of file
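End to end, the new behavior surfaces through `push_to_hf_hub(..., safe_serialization=...)` and the safetensors-first fallback in `load_state_dict_from_hf`. A hedged usage sketch; the repo id is a placeholder and a configured Hugging Face token is assumed:

```python
import timm
from timm.models._hub import load_state_dict_from_hf, push_to_hf_hub

model = timm.create_model('resnet18', pretrained=True)

# safe_serialization="both" uploads model.safetensors alongside
# pytorch_model.bin, so older timm versions can still load the repo.
push_to_hf_hub(model, 'my-user/resnet18-demo', safe_serialization='both')

# On load, _get_safe_alternatives('pytorch_model.bin') yields
# 'model.safetensors' then 'pytorch_model.safetensors', so a safetensors
# file is preferred; torch.load of pytorch_model.bin remains the fallback.
state_dict = load_state_dict_from_hf('my-user/resnet18-demo')
```

Since `safe_serialization` defaults to `False`, existing push flows keep producing `pytorch_model.bin` only.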