Fix learning rate scaling to include gradient accumulation steps
parent 95b61eb500
commit 235eac76c9

dinov2/utils
@@ -22,7 +22,7 @@ def apply_scaling_rules_to_cfg(cfg):  # to fix
     if cfg.optim.scaling_rule == "sqrt_wrt_1024":
         base_lr = cfg.optim.base_lr
         cfg.optim.lr = base_lr
-        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * cfg.train.grad_accum_steps * distributed.get_global_size() / 1024.0)
         logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
     else:
         raise NotImplementedError
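For reference, a minimal sketch of what the patched sqrt_wrt_1024 rule computes. The base learning rate, batch size, accumulation steps, and world size below are illustrative assumptions, not values taken from the repository.

import math

# Illustrative values (assumptions), chosen so the effective batch size is 1024.
base_lr = 4e-3
batch_size_per_gpu = 32
grad_accum_steps = 4   # factor newly included by this commit
world_size = 8         # what distributed.get_global_size() would return

effective_batch_size = batch_size_per_gpu * grad_accum_steps * world_size  # 1024
lr = base_lr * math.sqrt(effective_batch_size / 1024.0)
print(lr)  # 0.004 -- unchanged here, since the effective batch size equals the 1024 reference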
@@ -43,6 +43,7 @@ def get_cfg_from_args(args):
     default_cfg = OmegaConf.create(dinov2_default_config)
     cfg = OmegaConf.load(args.config_file)
     cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+    cfg.train.grad_accum_steps = max(1, getattr(cfg.train, "grad_accum_steps", 1))
     return cfg
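A small sketch of how the new default interacts with a command-line override. The config keys and values are assumptions for illustration, and OmegaConf.from_dotlist stands in for OmegaConf.from_cli so the snippet runs without command-line arguments.

from omegaconf import OmegaConf

# Hypothetical defaults standing in for dinov2_default_config (assumption).
default_cfg = OmegaConf.create({"train": {"batch_size_per_gpu": 32, "grad_accum_steps": 1}})
# Same dotlist syntax that OmegaConf.from_cli would parse from args.opts.
cli_cfg = OmegaConf.from_dotlist(["train.grad_accum_steps=4"])
cfg = OmegaConf.merge(default_cfg, cli_cfg)

# The line added in this hunk: clamp to at least 1 so later code can rely on the key.
cfg.train.grad_accum_steps = max(1, getattr(cfg.train, "grad_accum_steps", 1))
print(cfg.train.grad_accum_steps)  # 4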