mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Make safetensor import option for now. Improve avg/clean checkpoints ext handling a bit (more consistent).
This commit is contained in:
parent
7d9e321b76
commit
d0b45c9b4d
@ -17,10 +17,14 @@ import os
|
||||
import glob
|
||||
import hashlib
|
||||
from timm.models import load_state_dict
|
||||
import safetensors.torch
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
DEFAULT_OUTPUT = "./average.pth"
|
||||
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
|
||||
DEFAULT_OUTPUT = "./averaged.pth"
|
||||
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
|
||||
parser.add_argument('--input', default='', type=str, metavar='PATH',
|
||||
@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
|
||||
parser.add_argument('--safetensors', action='store_true',
|
||||
help='Save weights using safetensors instead of the default torch way (pickle).')
|
||||
|
||||
|
||||
def checkpoint_metric(checkpoint_path):
|
||||
if not checkpoint_path or not os.path.isfile(checkpoint_path):
|
||||
return {}
|
||||
@ -63,14 +68,20 @@ def main():
|
||||
if args.safetensors and args.output == DEFAULT_OUTPUT:
|
||||
# Default path changes if using safetensors
|
||||
args.output = DEFAULT_SAFE_OUTPUT
|
||||
if args.safetensors and not args.output.endswith(".safetensors"):
|
||||
|
||||
output, output_ext = os.path.splitext(args.output)
|
||||
if not output_ext:
|
||||
output_ext = ('.safetensors' if args.safetensors else '.pth')
|
||||
output = output + output_ext
|
||||
|
||||
if args.safetensors and not output_ext == ".safetensors":
|
||||
print(
|
||||
"Warning: saving weights as safetensors but output file extension is not "
|
||||
f"set to '.safetensors': {args.output}"
|
||||
)
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
if os.path.exists(output):
|
||||
print("Error: Output filename ({}) already exists.".format(output))
|
||||
exit(1)
|
||||
|
||||
pattern = args.input
|
||||
@ -87,22 +98,27 @@ def main():
|
||||
checkpoint_metrics.append((metric, c))
|
||||
checkpoint_metrics = list(sorted(checkpoint_metrics))
|
||||
checkpoint_metrics = checkpoint_metrics[-args.n:]
|
||||
print("Selected checkpoints:")
|
||||
[print(m, c) for m, c in checkpoint_metrics]
|
||||
if checkpoint_metrics:
|
||||
print("Selected checkpoints:")
|
||||
[print(m, c) for m, c in checkpoint_metrics]
|
||||
avg_checkpoints = [c for m, c in checkpoint_metrics]
|
||||
else:
|
||||
avg_checkpoints = checkpoints
|
||||
print("Selected checkpoints:")
|
||||
[print(c) for c in checkpoints]
|
||||
if avg_checkpoints:
|
||||
print("Selected checkpoints:")
|
||||
[print(c) for c in checkpoints]
|
||||
|
||||
if not avg_checkpoints:
|
||||
print('Error: No checkpoints found to average.')
|
||||
exit(1)
|
||||
|
||||
avg_state_dict = {}
|
||||
avg_counts = {}
|
||||
for c in avg_checkpoints:
|
||||
new_state_dict = load_state_dict(c, args.use_ema)
|
||||
if not new_state_dict:
|
||||
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
|
||||
print(f"Error: Checkpoint ({c}) doesn't exist")
|
||||
continue
|
||||
|
||||
for k, v in new_state_dict.items():
|
||||
if k not in avg_state_dict:
|
||||
avg_state_dict[k] = v.clone().to(dtype=torch.float64)
|
||||
@ -122,16 +138,14 @@ def main():
|
||||
final_state_dict[k] = v.to(dtype=torch.float32)
|
||||
|
||||
if args.safetensors:
|
||||
safetensors.torch.save_file(final_state_dict, args.output)
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
safetensors.torch.save_file(final_state_dict, output)
|
||||
else:
|
||||
try:
|
||||
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
|
||||
except:
|
||||
torch.save(final_state_dict, args.output)
|
||||
torch.save(final_state_dict, output)
|
||||
|
||||
with open(args.output, 'rb') as f:
|
||||
with open(output, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
|
||||
print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -11,9 +11,14 @@ import torch
|
||||
import argparse
|
||||
import os
|
||||
import hashlib
|
||||
import safetensors.torch
|
||||
import shutil
|
||||
import tempfile
|
||||
from timm.models import load_state_dict
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
|
||||
help='output path')
|
||||
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
|
||||
help='use ema version of weights if present')
|
||||
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
|
||||
help='no hash in output filename')
|
||||
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
|
||||
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
|
||||
parser.add_argument('--safetensors', action='store_true',
|
||||
help='Save weights using safetensors instead of the default torch way (pickle).')
|
||||
|
||||
_TEMP_NAME = './_checkpoint.pth'
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
@ -37,10 +42,24 @@ def main():
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
exit(1)
|
||||
|
||||
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
|
||||
clean_checkpoint(
|
||||
args.checkpoint,
|
||||
args.output,
|
||||
not args.no_use_ema,
|
||||
args.no_hash,
|
||||
args.clean_aux_bn,
|
||||
safe_serialization=args.safetensors,
|
||||
)
|
||||
|
||||
|
||||
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
|
||||
def clean_checkpoint(
|
||||
checkpoint,
|
||||
output,
|
||||
use_ema=True,
|
||||
no_hash=False,
|
||||
clean_aux_bn=False,
|
||||
safe_serialization: bool=False,
|
||||
):
|
||||
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
|
||||
if checkpoint and os.path.isfile(checkpoint):
|
||||
print("=> Loading checkpoint '{}'".format(checkpoint))
|
||||
@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
|
||||
new_state_dict[name] = v
|
||||
print("=> Loaded state_dict from '{}'".format(checkpoint))
|
||||
|
||||
if safe_serialization:
|
||||
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
|
||||
else:
|
||||
try:
|
||||
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
|
||||
except:
|
||||
torch.save(new_state_dict, _TEMP_NAME)
|
||||
|
||||
with open(_TEMP_NAME, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
ext = ''
|
||||
if output:
|
||||
checkpoint_root, checkpoint_base = os.path.split(output)
|
||||
checkpoint_base = os.path.splitext(checkpoint_base)[0]
|
||||
checkpoint_base, ext = os.path.splitext(checkpoint_base)
|
||||
else:
|
||||
checkpoint_root = ''
|
||||
checkpoint_base = os.path.splitext(checkpoint)[0]
|
||||
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
|
||||
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
|
||||
checkpoint_base = os.path.split(checkpoint)[1]
|
||||
checkpoint_base = os.path.splitext(checkpoint_base)[0]
|
||||
|
||||
temp_filename = '__' + checkpoint_base
|
||||
if safe_serialization:
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
safetensors.torch.save_file(new_state_dict, temp_filename)
|
||||
else:
|
||||
torch.save(new_state_dict, temp_filename)
|
||||
|
||||
with open(temp_filename, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
if ext:
|
||||
final_ext = ext
|
||||
else:
|
||||
final_ext = ('.safetensors' if safe_serialization else '.pth')
|
||||
|
||||
if no_hash:
|
||||
final_filename = checkpoint_base + final_ext
|
||||
else:
|
||||
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
|
||||
|
||||
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
|
||||
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
|
||||
return final_filename
|
||||
else:
|
||||
|
@ -7,7 +7,11 @@ import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import safetensors.torch
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
# Check if safetensors or not and load weights accordingly
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
@ -7,15 +7,21 @@ from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
import safetensors.torch
|
||||
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
|
||||
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
|
||||
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
|
||||
|
||||
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
Returns the location of the directory where models are cached (and creates it if necessary).
|
||||
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
|
||||
# Look for .safetensors alternatives and load from it if it exists
|
||||
for safe_filename in _get_safe_alternatives(filename):
|
||||
try:
|
||||
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
|
||||
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
|
||||
return safetensors.torch.load_file(cached_safe_file, device="cpu")
|
||||
except EntryNotFoundError:
|
||||
pass
|
||||
if _has_safetensors:
|
||||
for safe_filename in _get_safe_alternatives(filename):
|
||||
try:
|
||||
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
|
||||
_logger.info(
|
||||
f"[{model_id}] Safe alternative available for '{filename}' "
|
||||
f"(as '{safe_filename}'). Loading weights using safetensors.")
|
||||
return safetensors.torch.load_file(cached_safe_file, device="cpu")
|
||||
except EntryNotFoundError:
|
||||
pass
|
||||
|
||||
# Otherwise, load using pytorch.load
|
||||
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
||||
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||
return torch.load(cached_file, map_location='cpu')
|
||||
|
||||
|
||||
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
|
||||
def save_config_for_hf(
|
||||
model,
|
||||
config_path: str,
|
||||
model_config: Optional[dict] = None
|
||||
):
|
||||
model_config = model_config or {}
|
||||
hf_config = {}
|
||||
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
|
||||
@ -220,8 +234,8 @@ def save_for_hf(
|
||||
model,
|
||||
save_directory: str,
|
||||
model_config: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False
|
||||
):
|
||||
safe_serialization: Union[bool, Literal["both"]] = False,
|
||||
):
|
||||
assert has_hf_hub(True)
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
@ -229,6 +243,7 @@ def save_for_hf(
|
||||
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
|
||||
tensors = model.state_dict()
|
||||
if safe_serialization is True or safe_serialization == "both":
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
|
||||
if safe_serialization is False or safe_serialization == "both":
|
||||
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
|
||||
@ -238,16 +253,16 @@ def save_for_hf(
|
||||
|
||||
|
||||
def push_to_hf_hub(
|
||||
model,
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
model_card: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False
|
||||
model,
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
model_card: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
|
||||
readme_text += f"```bibtex\n{c}\n```\n"
|
||||
return readme_text
|
||||
|
||||
|
||||
def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
||||
"""Returns potential safetensors alternatives for a given filename.
|
||||
|
||||
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
||||
"""
|
||||
if filename == HF_WEIGHTS_NAME:
|
||||
yield HF_SAFE_WEIGHTS_NAME
|
||||
if filename.endswith(".bin"):
|
||||
yield filename[:-4] + ".safetensors"
|
||||
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
|
||||
return filename[:-4] + ".safetensors"
|
||||
|
Loading…
x
Reference in New Issue
Block a user