Add head-init-scale and head-init-bias args that work for all models, fix #1718

Ross Wightman 2023-04-14 17:59:23 -07:00
parent 34df125be6
commit ec6cca4b37

@@ -141,6 +141,10 @@ group.add_argument('--grad-checkpointing', action='store_true', default=False,
group.add_argument('--fast-norm', default=False, action='store_true',
                   help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,
                   help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,
                   help='Head initialization bias value')

scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
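
With these two flags in place, the classifier head's initial scale and bias can be set from the command line for any model. A possible invocation (model name, data path, and values are illustrative, not from this commit):

    python train.py --data-dir /imagenet --model resnet50 --head-init-scale 0.001 --head-init-bias -6.9
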
@@ -427,6 +431,13 @@ def main():
        checkpoint_path=args.initial_checkpoint,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
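
A minimal standalone sketch of what the two flags do, outside the training script (the model name and flag values here are assumptions for illustration; it presumes a timm model whose get_classifier() returns an nn.Linear head):

    import torch
    import torch.nn as nn
    import timm

    model = timm.create_model('resnet50', num_classes=10)

    head_init_scale = 0.001  # would come from --head-init-scale
    head_init_bias = -6.9    # would come from --head-init-bias

    if head_init_scale is not None:
        # scale the head's random init in place; no_grad avoids autograd tracking
        with torch.no_grad():
            model.get_classifier().weight.mul_(head_init_scale)
            model.get_classifier().bias.mul_(head_init_scale)
    if head_init_bias is not None:
        # set the bias to a fixed constant
        nn.init.constant_(model.get_classifier().bias, head_init_bias)

Note the ordering: the constant bias init runs after the scaling, so when both flags are given, --head-init-bias overrides the scaled bias while the weight scaling still applies.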