Add support to clean_checkpoint.py to remove aux_bn weights/biases from SplitBatchNorm
parent 2a88412413
commit cc0b1f4130
@@ -21,7 +21,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='output path')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
+parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
+                    help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
 
 _TEMP_NAME = './_checkpoint.pth'
 
@@ -48,6 +49,10 @@ def main():
         else:
             assert False
         for k, v in state_dict.items():
+            if args.clean_aux_bn and 'aux_bn' in k:
+                # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
+                # load with the unmodified model using BatchNorm2d.
+                continue
             name = k[7:] if k.startswith('module') else k
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(args.checkpoint))
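
As a rough illustration of what the new flag does (not part of the commit), the sketch below applies the same aux_bn filtering to a checkpoint's state_dict on its own; the input and output paths are hypothetical placeholders.

import torch

# Hypothetical input/output paths, for illustration only.
ckpt_path = 'model_best.pth.tar'
out_path = './cleaned_checkpoint.pth'

checkpoint = torch.load(ckpt_path, map_location='cpu')
# Training checkpoints usually wrap weights under 'state_dict'; otherwise use the raw dict.
state_dict = checkpoint['state_dict'] if isinstance(checkpoint, dict) and 'state_dict' in checkpoint else checkpoint

new_state_dict = {}
for k, v in state_dict.items():
    if 'aux_bn' in k:
        # Skip the auxiliary per-split BN buffers added by SplitBatchNorm; the remaining
        # keys match the layout of a model built with plain BatchNorm2d.
        continue
    name = k[7:] if k.startswith('module') else k  # strip DataParallel 'module.' prefix
    new_state_dict[name] = v

torch.save(new_state_dict, out_path)

With the change applied, the script itself would presumably be run as something like python clean_checkpoint.py <checkpoint> --clean-aux-bn (optionally with --use-ema and --output), assuming the checkpoint path is the script's existing first argument referenced as args.checkpoint.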