From ec6cca4b37ecf1f4c6303e354730f3e3dbc596ba Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 14 Apr 2023 17:59:23 -0700
Subject: [PATCH] Add head-init-scale and head-init-bias args that works for all models, fix #1718

---
 Review note: the head re-init fetches the classifier once, skips bias
 scaling for heads built without a bias, and raises a clear error when
 --head-init-bias is requested for such a head.

 train.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/train.py b/train.py
index 816f4ae8..83663cc6 100755
--- a/train.py
+++ b/train.py
@@ -141,6 +141,10 @@ group.add_argument('--grad-checkpointing', action='store_true', default=False,
 group.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
 group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
+group.add_argument('--head-init-scale', default=None, type=float,
+                   help='Head initialization scale')
+group.add_argument('--head-init-bias', default=None, type=float,
+                   help='Head initialization bias value')
 
 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
@@ -427,6 +431,21 @@ def main():
         checkpoint_path=args.initial_checkpoint,
         **args.model_kwargs,
     )
+    if args.head_init_scale is not None or args.head_init_bias is not None:
+        # Look the classifier up once; both options below operate on the same module.
+        head = model.get_classifier()
+        if args.head_init_scale is not None:
+            # In-place rescale of the freshly created head, kept out of autograd.
+            with torch.no_grad():
+                head.weight.mul_(args.head_init_scale)
+                # Heads created with bias=False expose `bias` as None.
+                if head.bias is not None:
+                    head.bias.mul_(args.head_init_scale)
+        if args.head_init_bias is not None:
+            if head.bias is None:
+                raise ValueError('--head-init-bias was set but the classifier has no bias parameter.')
+            nn.init.constant_(head.bias, args.head_init_bias)
+
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly