mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add head-init-scale and head-init-bias args that work for all models, fix #1718
This commit is contained in:
parent
34df125be6
commit
ec6cca4b37
11
train.py
11
train.py
@ -141,6 +141,10 @@ group.add_argument('--grad-checkpointing', action='store_true', default=False,
|
||||
group.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
|
||||
group.add_argument('--head-init-scale', default=None, type=float,
|
||||
help='Head initialization scale')
|
||||
group.add_argument('--head-init-bias', default=None, type=float,
|
||||
help='Head initialization bias value')
|
||||
|
||||
scripting_group = group.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
@ -427,6 +431,13 @@ def main():
|
||||
checkpoint_path=args.initial_checkpoint,
|
||||
**args.model_kwargs,
|
||||
)
|
||||
if args.head_init_scale is not None:
|
||||
with torch.no_grad():
|
||||
model.get_classifier().weight.mul_(args.head_init_scale)
|
||||
model.get_classifier().bias.mul_(args.head_init_scale)
|
||||
if args.head_init_bias is not None:
|
||||
nn.init.constant_(model.get_classifier().bias, args.head_init_bias)
|
||||
|
||||
if args.num_classes is None:
|
||||
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
||||
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
|
||||
|
Loading…
x
Reference in New Issue
Block a user