mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Clean a1/a2/3 rsb _0 checkpoints properly, fix v2 loading.
This commit is contained in:
parent
d123042605
commit
ae1ff5792f
@ -13,6 +13,7 @@ import os
|
|||||||
import hashlib
|
import hashlib
|
||||||
import shutil
|
import shutil
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from timm.models.helpers import load_state_dict
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
||||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||||
@ -37,17 +38,8 @@ def main():
|
|||||||
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
|
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
|
||||||
if args.checkpoint and os.path.isfile(args.checkpoint):
|
if args.checkpoint and os.path.isfile(args.checkpoint):
|
||||||
print("=> Loading checkpoint '{}'".format(args.checkpoint))
|
print("=> Loading checkpoint '{}'".format(args.checkpoint))
|
||||||
checkpoint = torch.load(args.checkpoint, map_location='cpu')
|
state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
|
||||||
|
new_state_dict = {}
|
||||||
new_state_dict = OrderedDict()
|
|
||||||
if isinstance(checkpoint, dict):
|
|
||||||
state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
|
|
||||||
if state_dict_key in checkpoint:
|
|
||||||
state_dict = checkpoint[state_dict_key]
|
|
||||||
else:
|
|
||||||
state_dict = checkpoint
|
|
||||||
else:
|
|
||||||
assert False
|
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if args.clean_aux_bn and 'aux_bn' in k:
|
if args.clean_aux_bn and 'aux_bn' in k:
|
||||||
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
|
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
|
||||||
|
@ -53,7 +53,7 @@ default_cfgs = {
|
|||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth',
|
||||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
|
||||||
'resnet50': _cfg(
|
'resnet50': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-00ca2c6a.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
'resnet50d': _cfg(
|
'resnet50d': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
|
||||||
|
@ -471,7 +471,7 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
|
|||||||
ResNetV2, variant, pretrained,
|
ResNetV2, variant, pretrained,
|
||||||
default_cfg=default_cfgs[variant],
|
default_cfg=default_cfgs[variant],
|
||||||
feature_cfg=feature_cfg,
|
feature_cfg=feature_cfg,
|
||||||
pretrained_custom_load=True,
|
pretrained_custom_load='_bit' in variant,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user