mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add support to clean_checkpoint.py to remove aux_bn weights/biases from SplitBatchNorm
This commit is contained in:
parent
2a88412413
commit
cc0b1f4130
@ -21,7 +21,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|||||||
help='output path')
|
help='output path')
|
||||||
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
||||||
help='use ema version of weights if present')
|
help='use ema version of weights if present')
|
||||||
|
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
|
||||||
|
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
|
||||||
|
|
||||||
_TEMP_NAME = './_checkpoint.pth'
|
_TEMP_NAME = './_checkpoint.pth'
|
||||||
|
|
||||||
@ -48,6 +49,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
|
if args.clean_aux_bn and 'aux_bn' in k:
|
||||||
|
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
|
||||||
|
# load with the unmodified model using BatchNorm2d.
|
||||||
|
continue
|
||||||
name = k[7:] if k.startswith('module') else k
|
name = k[7:] if k.startswith('module') else k
|
||||||
new_state_dict[name] = v
|
new_state_dict[name] = v
|
||||||
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
|
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user