diff --git a/clean_checkpoint.py b/clean_checkpoint.py index fef3104a..bc86f2ac 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -21,7 +21,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH', help='output path') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') - +parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', + help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') _TEMP_NAME = './_checkpoint.pth' @@ -48,6 +49,10 @@ def main(): else: assert False for k, v in state_dict.items(): + if args.clean_aux_bn and 'aux_bn' in k: + # If all aux_bn keys are removed, the SplitBN layers will end up as normal and + # load with the unmodified model using BatchNorm2d. + continue name = k[7:] if k.startswith('module') else k new_state_dict[name] = v print("=> Loaded state_dict from '{}'".format(args.checkpoint))