1. Added a simple multi step LR scheduler
parent
6d8272e92c
commit
daab57a6d9
|
@ -0,0 +1,65 @@
|
|||
""" MultiStep LR Scheduler
|
||||
|
||||
Basic multi step LR schedule with warmup, noise.
|
||||
"""
|
||||
import torch
|
||||
import bisect
|
||||
from timm.scheduler.scheduler import Scheduler
|
||||
from typing import List
|
||||
|
||||
class MultiStepLRScheduler(Scheduler):
    """Multi-step LR schedule with optional linear warmup.

    After warmup, every base value is multiplied by ``decay_rate ** k``,
    where ``k`` is the number of milestones in ``decay_t`` that are
    ``<= t + 1``. While ``t < warmup_t`` the value instead ramps linearly
    from ``warmup_lr_init`` towards each base value.

    Args:
        optimizer: Optimizer whose ``"lr"`` param-group field is scheduled.
        decay_t: Milestones at which the decay is applied, counted in epochs
            or updates depending on ``t_in_epochs``. May be given in any
            order; a sorted copy is stored.
        decay_rate: Multiplicative factor applied once per milestone reached.
        warmup_t: Length of the warmup phase; a falsy value disables it.
        warmup_lr_init: Value the warmup ramp starts from at ``t == 0``.
        t_in_epochs: When true the schedule is driven by
            ``get_epoch_values``; otherwise by ``get_update_values``.
        noise_range_t: Forwarded unchanged to the base ``Scheduler``.
        noise_pct: Forwarded unchanged to the base ``Scheduler``.
        noise_std: Forwarded unchanged to the base ``Scheduler``.
        noise_seed: Forwarded unchanged to the base ``Scheduler``.
        initialize: Forwarded unchanged to the base ``Scheduler``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 decay_t: List[int],
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        # get_curr_decay_steps() bisects this list, which is only valid on
        # sorted data. Store a sorted copy so unordered milestones still work
        # and later mutation of the caller's list cannot affect the schedule.
        self.decay_t = sorted(decay_t)
        self.decay_rate = decay_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-step increment that takes each group from warmup_lr_init to
            # its base value in exactly warmup_t steps.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            # Placeholder values; _get_lr never reads them when warmup_t is
            # falsy because `t < warmup_t` cannot hold for t >= 0.
            self.warmup_steps = [1 for _ in self.base_values]

    def get_curr_decay_steps(self, t) -> int:
        """Return how many milestones in ``decay_t`` are ``<= t + 1``.

        Relies on ``self.decay_t`` being sorted in ascending order.
        """
        # NOTE(review): the `+ 1` makes a milestone m take effect at t == m - 1;
        # presumably this matches how the training loop calls the scheduler.
        # Confirm against the caller before changing.
        return bisect.bisect_right(self.decay_t, t + 1)

    def _get_lr(self, t):
        """Compute one value per base value for time step ``t``."""
        if t < self.warmup_t:
            # Linear warmup ramp.
            return [self.warmup_lr_init + t * s for s in self.warmup_steps]
        # The decay factor is the same for every group, so compute it once.
        decay = self.decay_rate ** self.get_curr_decay_steps(t)
        return [v * decay for v in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return values for ``epoch``, or ``None`` if time is counted in updates."""
        if self.t_in_epochs:
            return self._get_lr(epoch)
        return None

    def get_update_values(self, num_updates: int):
        """Return values for ``num_updates``, or ``None`` if time is counted in epochs."""
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        return None
|
|
@ -5,6 +5,7 @@ from .cosine_lr import CosineLRScheduler
|
|||
from .tanh_lr import TanhLRScheduler
|
||||
from .step_lr import StepLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
from .multistep_lr import MultiStepLRScheduler
|
||||
|
||||
|
||||
def create_scheduler(args, optimizer):
|
||||
|
@ -67,6 +68,18 @@ def create_scheduler(args, optimizer):
|
|||
noise_std=getattr(args, 'lr_noise_std', 1.),
|
||||
noise_seed=getattr(args, 'seed', 42),
|
||||
)
|
||||
elif args.sched == 'multistep':
|
||||
lr_scheduler = MultiStepLRScheduler(
|
||||
optimizer,
|
||||
decay_t=args.decay_epochs,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
noise_range_t=noise_range,
|
||||
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
|
||||
noise_std=getattr(args, 'lr_noise_std', 1.),
|
||||
noise_seed=getattr(args, 'seed', 42),
|
||||
)
|
||||
elif args.sched == 'plateau':
|
||||
mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
|
||||
lr_scheduler = PlateauLRScheduler(
|
||||
|
|
Loading…
Reference in New Issue