2020-07-27 13:44:56 -07:00
|
|
|
""" Scheduler Factory

Hacked together by / Copyright 2021 Ross Wightman
"""
|
2019-06-19 17:19:37 -07:00
|
|
|
from .cosine_lr import CosineLRScheduler
|
2021-07-09 16:18:27 -04:00
|
|
|
from .multistep_lr import MultiStepLRScheduler
|
2021-09-01 17:33:11 -07:00
|
|
|
from .plateau_lr import PlateauLRScheduler
|
|
|
|
from .poly_lr import PolyLRScheduler
|
|
|
|
from .step_lr import StepLRScheduler
|
|
|
|
from .tanh_lr import TanhLRScheduler
|
2019-04-11 21:32:16 -07:00
|
|
|
|
|
|
|
|
|
|
|
def _resolve_noise_range(lr_noise, num_epochs):
    """Scale the configured LR-noise bound(s) by ``num_epochs``.

    Returns ``None`` when no noise is configured, a single scaled value for a
    scalar or one-element sequence, and a list of scaled values otherwise.
    """
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        scaled = [n * num_epochs for n in lr_noise]
        if not scaled:
            # An empty sequence carries no range information; treat it as
            # "noise disabled" instead of forwarding an empty list.
            return None
        # A one-element sequence is collapsed to its single value.
        return scaled[0] if len(scaled) == 1 else scaled
    return lr_noise * num_epochs


def create_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.sched``.

    Args:
        args: Namespace-like config object. ``epochs`` and ``sched`` are
            always read. The remaining required attributes (``min_lr``,
            ``warmup_lr``, ``warmup_epochs``, ``cooldown_epochs``,
            ``decay_epochs``, ``decay_milestones``, ``decay_rate``,
            ``patience_epochs``) are only read inside the branch that needs
            them. Optional attributes, read via ``getattr`` with defaults:
            ``lr_noise``, ``lr_noise_pct``, ``lr_noise_std``, ``seed``,
            ``lr_cycle_mul``, ``lr_cycle_decay``, ``lr_cycle_limit``,
            ``lr_k_decay``, ``eval_metric``.
        optimizer: Passed through unchanged as the first positional argument
            of the selected scheduler class.

    Returns:
        A ``(lr_scheduler, num_epochs)`` tuple. ``lr_scheduler`` is ``None``
        when ``args.sched`` matches none of ``'cosine'``, ``'tanh'``,
        ``'step'``, ``'multistep'``, ``'plateau'`` or ``'poly'``.
        ``num_epochs`` is ``args.epochs``, except for ``'cosine'``, ``'tanh'``
        and ``'poly'`` where it becomes the scheduler's
        ``get_cycle_length()`` plus ``args.cooldown_epochs``.
    """
    num_epochs = args.epochs

    # Noise settings are forwarded to every scheduler type.
    noise_args = dict(
        noise_range_t=_resolve_noise_range(getattr(args, 'lr_noise', None), num_epochs),
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )
    # Cycle settings are forwarded only to the cosine / tanh / poly schedulers.
    cycle_args = dict(
        cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
        cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
        cycle_limit=getattr(args, 'lr_cycle_limit', 1),
    )

    lr_scheduler = None
    if args.sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            **cycle_args,
            **noise_args,
        )
        # Total run length = full cycle length reported by the scheduler + cooldown.
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            t_in_epochs=True,
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif args.sched == 'multistep':
        lr_scheduler = MultiStepLRScheduler(
            optimizer,
            decay_t=args.decay_milestones,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_args,
        )
    elif args.sched == 'plateau':
        # Metrics whose name contains 'loss' are minimised; anything else is maximised.
        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_args,
        )
    elif args.sched == 'poly':
        lr_scheduler = PolyLRScheduler(
            optimizer,
            power=args.decay_rate,  # overloading 'decay_rate' as polynomial power
            t_initial=num_epochs,
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            k_decay=getattr(args, 'lr_k_decay', 1.0),
            **cycle_args,
            **noise_args,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs

    return lr_scheduler, num_epochs
|