mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add checkpoint clean script, add link to pretrained resnext50 weights
This commit is contained in:
parent
6e9697eb9c
commit
8a33a6c90a
45
clean_checkpoint.py
Normal file
45
clean_checkpoint.py
Normal file
@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import argparse
|
||||
import os
|
||||
import hashlib
|
||||
from collections import OrderedDict
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
|
||||
help='output path')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
exit(1)
|
||||
|
||||
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
|
||||
if args.checkpoint and os.path.isfile(args.checkpoint):
|
||||
print("=> Loading checkpoint '{}'".format(args.checkpoint))
|
||||
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
||||
|
||||
new_state_dict = OrderedDict()
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module') else k
|
||||
new_state_dict[name] = v
|
||||
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
|
||||
|
||||
torch.save(new_state_dict, args.output)
|
||||
with open(args.output, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
|
||||
else:
|
||||
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -16,13 +16,14 @@ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152
|
||||
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
|
||||
|
||||
|
||||
def _cfg(url=''):
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv1', 'classifier': 'fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
@ -32,7 +33,8 @@ default_cfgs = {
|
||||
'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
|
||||
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
|
||||
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
|
||||
'resnext50_32x4d': _cfg(url=''),
|
||||
'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1',
|
||||
interpolation='bicubic'),
|
||||
'resnext101_32x4d': _cfg(url=''),
|
||||
'resnext101_64x4d': _cfg(url=''),
|
||||
'resnext152_32x4d': _cfg(url=''),
|
||||
|
@ -23,7 +23,7 @@ __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
|
||||
'seresnext50_32x4d', 'seresnext101_32x4d']
|
||||
|
||||
|
||||
def _cfg(url=''):
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
|
Loading…
x
Reference in New Issue
Block a user