From cc0b1f41305489018da6360e1f2a1a32e4d337b7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 12 Jan 2020 17:52:19 -0800 Subject: [PATCH] Add support to clean_checkpoint.py to remove aux_bn weights/biases from SplitBatchNorm --- clean_checkpoint.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index fef3104a..bc86f2ac 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -21,7 +21,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH', help='output path') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') - +parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', + help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') _TEMP_NAME = './_checkpoint.pth' @@ -48,6 +49,10 @@ def main(): else: assert False for k, v in state_dict.items(): + if args.clean_aux_bn and 'aux_bn' in k: + # If all aux_bn keys are removed, the SplitBN layers will end up as normal and + # load with the unmodified model using BatchNorm2d. + continue name = k[7:] if k.startswith('module') else k new_state_dict[name] = v print("=> Loaded state_dict from '{}'".format(args.checkpoint))