mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add checkpoint averaging script. Add headers, shebangs, exec perms to all scripts
This commit is contained in:
parent
4666cc9aed
commit
40fea63ebe
113
avg_checkpoint.py
Executable file
113
avg_checkpoint.py
Executable file
@ -0,0 +1,113 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
""" Checkpoint Averaging Script
|
||||||
|
|
||||||
|
This script averages all model weights for checkpoints in specified path that match
|
||||||
|
the specified filter wildcard. All checkpoints must be from the exact same model.
|
||||||
|
|
||||||
|
For any hope of decent results, the checkpoints should be from the same or child
|
||||||
|
(via resumes) training session. This can be viewed as similar to maintaining running
|
||||||
|
EMA (exponential moving average) of the model weights or performing SWA (stochastic
|
||||||
|
weight averaging), but post-training.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import hashlib
|
||||||
|
from timm.models.helpers import load_state_dict
|
||||||
|
|
||||||
|
# Command-line interface for the checkpoint averaging script.
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')

# Where to find the checkpoints and where to write the averaged weights.
parser.add_argument(
    '--input', metavar='PATH', type=str, default='',
    help='path to base input folder containing checkpoints')
parser.add_argument(
    '--filter', metavar='WILDCARD', type=str, default='*.pth.tar',
    help='checkpoint filter (path wildcard)')
parser.add_argument(
    '--output', metavar='PATH', type=str, default='./averaged.pth',
    help='output filename')

# Negative switches: the script derives the positive options from these.
parser.add_argument(
    '--no-use-ema', action='store_true', dest='no_use_ema',
    help='Force not using ema version of weights (if present)')
parser.add_argument(
    '--no-sort', action='store_true', dest='no_sort',
    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')

# How many of the metric-sorted checkpoints take part in the average.
parser.add_argument(
    '-n', metavar='N', type=int, default=10,
    help='Number of checkpoints to average')
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_metric(checkpoint_path):
    """Return the metric stored in a checkpoint file, or None if unavailable.

    Args:
        checkpoint_path: Path to a checkpoint saved with ``torch.save``.

    Returns:
        The value stored under the ``'metric'`` key of the checkpoint, or
        ``None`` when the path is empty, is not a regular file, or the
        checkpoint carries no ``'metric'`` entry.
    """
    # Always signal "no metric" with None so callers can filter on a single
    # sentinel; a placeholder such as {} would be treated as a real metric
    # and then fail to sort against numeric metrics.
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Only dict-like checkpoints can carry a 'metric' entry.
    if isinstance(checkpoint, dict):
        return checkpoint.get('metric')
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Average the weights of the selected checkpoints and save the result.

    Parses the command line, collects checkpoints matching the input folder
    and filter wildcard, optionally keeps only the ``n`` checkpoints with the
    highest stored metric, averages every state dict entry in float64, and
    writes the float32 result to the output file, printing its SHA256.

    Raises:
        SystemExit: With status 1 if the output file already exists or no
            checkpoints were selected for averaging.
    """
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    # Refuse to overwrite an existing file.
    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        raise SystemExit(1)

    # Build the glob pattern. A separator is only inserted when there is an
    # input folder to separate from the filter; with an empty input the
    # filter is applied relative to the current directory rather than the
    # filesystem root.
    pattern = args.input
    if args.input and not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        # Keep only checkpoints that carry a metric, ordered ascending so the
        # best n are at the tail of the list.
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        checkpoint_metrics.sort()
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        avg_checkpoints = [c for _, c in checkpoint_metrics]
    else:
        checkpoint_metrics = None
        avg_checkpoints = checkpoints

    # Averaging nothing would silently write an empty state dict.
    if not avg_checkpoints:
        print("Error: No checkpoints selected for averaging (pattern: {})".format(pattern))
        raise SystemExit(1)

    print("Selected checkpoints:")
    if checkpoint_metrics is not None:
        for m, c in checkpoint_metrics:
            print(m, c)
    else:
        for c in avg_checkpoints:
            print(c)

    # Accumulate sums in float64 to limit rounding error, tracking a per-key
    # count so keys missing from some checkpoints are still averaged correctly.
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print("Error: Checkpoint ({}) doesn't exist".format(c))
            continue

        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    torch.save(final_state_dict, args.output)
    with open(args.output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print("=> Saved state_dict to '{}', SHA256: {}".format(args.output, sha_hash))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
11
clean_checkpoint.py
Normal file → Executable file
11
clean_checkpoint.py
Normal file → Executable file
@ -1,3 +1,12 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
""" Checkpoint Cleaning Script
|
||||||
|
|
||||||
|
Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc.
|
||||||
|
and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256
|
||||||
|
calculation for model zoo compatibility.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
@ -5,7 +14,7 @@ import hashlib
|
|||||||
import shutil
|
import shutil
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
||||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||||
help='path to latest checkpoint (default: none)')
|
help='path to latest checkpoint (default: none)')
|
||||||
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
||||||
|
12
inference.py
Normal file → Executable file
12
inference.py
Normal file → Executable file
@ -1,10 +1,10 @@
|
|||||||
"""Sample PyTorch Inference script
|
#!/usr/bin/env python
|
||||||
|
"""PyTorch Inference Script
|
||||||
|
|
||||||
|
An example inference script that outputs top-k class ids for images in a folder into a csv.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -5,12 +5,11 @@ import logging
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, checkpoint_path, use_ema=False):
|
def load_state_dict(checkpoint_path, use_ema=False):
|
||||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||||
state_dict_key = ''
|
|
||||||
if isinstance(checkpoint, dict):
|
|
||||||
state_dict_key = 'state_dict'
|
state_dict_key = 'state_dict'
|
||||||
|
if isinstance(checkpoint, dict):
|
||||||
if use_ema and 'state_dict_ema' in checkpoint:
|
if use_ema and 'state_dict_ema' in checkpoint:
|
||||||
state_dict_key = 'state_dict_ema'
|
state_dict_key = 'state_dict_ema'
|
||||||
if state_dict_key and state_dict_key in checkpoint:
|
if state_dict_key and state_dict_key in checkpoint:
|
||||||
@ -19,15 +18,21 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
|
|||||||
# strip `module.` prefix
|
# strip `module.` prefix
|
||||||
name = k[7:] if k.startswith('module') else k
|
name = k[7:] if k.startswith('module') else k
|
||||||
new_state_dict[name] = v
|
new_state_dict[name] = v
|
||||||
model.load_state_dict(new_state_dict)
|
state_dict = new_state_dict
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint)
|
state_dict = checkpoint
|
||||||
logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
|
logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
||||||
|
return state_dict
|
||||||
else:
|
else:
|
||||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(model, checkpoint_path, use_ema=False):
|
||||||
|
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
|
||||||
def resume_checkpoint(model, checkpoint_path):
|
def resume_checkpoint(model, checkpoint_path):
|
||||||
other_state = {}
|
other_state = {}
|
||||||
resume_epoch = None
|
resume_epoch = None
|
||||||
|
17
train.py
Normal file → Executable file
17
train.py
Normal file → Executable file
@ -1,4 +1,19 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
""" ImageNet Training Script
|
||||||
|
|
||||||
|
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
|
||||||
|
training results with some of the latest networks and training techniques. It favours canonical PyTorch
|
||||||
|
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
|
||||||
|
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
|
||||||
|
|
||||||
|
This script was started from an early version of the PyTorch ImageNet example
|
||||||
|
(https://github.com/pytorch/examples/tree/master/imagenet)
|
||||||
|
|
||||||
|
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
|
||||||
|
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
@ -35,7 +50,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
|
|||||||
help='YAML config file specifying default arguments')
|
help='YAML config file specifying default arguments')
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Training')
|
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
||||||
# Dataset / Model parameters
|
# Dataset / Model parameters
|
||||||
parser.add_argument('data', metavar='DIR',
|
parser.add_argument('data', metavar='DIR',
|
||||||
help='path to dataset')
|
help='path to dataset')
|
||||||
|
14
validate.py
Normal file → Executable file
14
validate.py
Normal file → Executable file
@ -1,7 +1,12 @@
|
|||||||
from __future__ import absolute_import
|
#!/usr/bin/env python
|
||||||
from __future__ import division
|
""" ImageNet Validation Script
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
|
This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
|
||||||
|
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
|
||||||
|
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman (https://github.com/rwightman)
|
||||||
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import csv
|
import csv
|
||||||
@ -182,6 +187,7 @@ def main():
|
|||||||
# validate all checkpoints in a path with same model
|
# validate all checkpoints in a path with same model
|
||||||
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
|
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
|
||||||
checkpoints += glob.glob(args.checkpoint + '/*.pth')
|
checkpoints += glob.glob(args.checkpoint + '/*.pth')
|
||||||
|
model_names = list_models(args.model)
|
||||||
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
|
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
|
||||||
else:
|
else:
|
||||||
if args.model == 'all':
|
if args.model == 'all':
|
||||||
@ -195,7 +201,7 @@ def main():
|
|||||||
model_cfgs = [(n, '') for n in model_names]
|
model_cfgs = [(n, '') for n in model_names]
|
||||||
|
|
||||||
if len(model_cfgs):
|
if len(model_cfgs):
|
||||||
print('Running bulk validation on these pretrained models:', ', '.join(model_names))
|
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
||||||
header_written = False
|
header_written = False
|
||||||
with open('./results-all.csv', mode='w') as cf:
|
with open('./results-all.csv', mode='w') as cf:
|
||||||
for m, c in model_cfgs:
|
for m, c in model_cfgs:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user