mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
LR scheduler update:
* add polynomial decay 'poly' * cleanup cycle specific args for cosine, poly, and tanh sched, t_mul -> cycle_mul, decay -> cycle_decay, default cycle_limit to 1 in each opt * add k-decay for cosine and poly sched as per https://arxiv.org/abs/2004.05909 * change default tanh ub/lb to push inflection to later epochs
This commit is contained in:
parent
492c0a4e20
commit
29a37e23ee
@ -1,5 +1,8 @@
|
||||
from .cosine_lr import CosineLRScheduler
|
||||
from .multistep_lr import MultiStepLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
from .poly_lr import PolyLRScheduler
|
||||
from .step_lr import StepLRScheduler
|
||||
from .tanh_lr import TanhLRScheduler
|
||||
|
||||
from .scheduler_factory import create_scheduler
|
||||
|
@ -1,8 +1,8 @@
|
||||
""" Cosine Scheduler
|
||||
|
||||
Cosine LR schedule with warmup, cycle/restarts, noise.
|
||||
Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
@ -22,23 +22,26 @@ class CosineLRScheduler(Scheduler):
|
||||
|
||||
Inspiration from
|
||||
https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
|
||||
|
||||
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
t_mul: float = 1.,
|
||||
lr_min: float = 0.,
|
||||
decay_rate: float = 1.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
cycle_limit=0,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
k_decay=1.0,
|
||||
initialize=True) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
@ -47,18 +50,19 @@ class CosineLRScheduler(Scheduler):
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
if t_initial == 1 and t_mul == 1 and decay_rate == 1:
|
||||
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
|
||||
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
|
||||
"rate since t_initial = t_mul = eta_mul = 1.")
|
||||
self.t_initial = t_initial
|
||||
self.t_mul = t_mul
|
||||
self.lr_min = lr_min
|
||||
self.decay_rate = decay_rate
|
||||
self.cycle_mul = cycle_mul
|
||||
self.cycle_decay = cycle_decay
|
||||
self.cycle_limit = cycle_limit
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.k_decay = k_decay
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
@ -72,22 +76,23 @@ class CosineLRScheduler(Scheduler):
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
|
||||
if self.t_mul != 1:
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
|
||||
t_i = self.t_mul ** i * self.t_initial
|
||||
t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
|
||||
if self.cycle_mul != 1:
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
|
||||
t_i = self.cycle_mul ** i * self.t_initial
|
||||
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
|
||||
else:
|
||||
i = t // self.t_initial
|
||||
t_i = self.t_initial
|
||||
t_curr = t - (self.t_initial * i)
|
||||
|
||||
gamma = self.decay_rate ** i
|
||||
lr_min = self.lr_min * gamma
|
||||
gamma = self.cycle_decay ** i
|
||||
lr_max_values = [v * gamma for v in self.base_values]
|
||||
k = self.k_decay
|
||||
|
||||
if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
|
||||
if i < self.cycle_limit:
|
||||
lrs = [
|
||||
lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
|
||||
self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
|
||||
for lr_max in lr_max_values
|
||||
]
|
||||
else:
|
||||
lrs = [self.lr_min for _ in self.base_values]
|
||||
@ -107,10 +112,8 @@ class CosineLRScheduler(Scheduler):
|
||||
return None
|
||||
|
||||
def get_cycle_length(self, cycles=0):
|
||||
if not cycles:
|
||||
cycles = self.cycle_limit
|
||||
cycles = max(1, cycles)
|
||||
if self.t_mul == 1.0:
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
return self.t_initial * cycles
|
||||
else:
|
||||
return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
|
||||
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
|
116
timm/scheduler/poly_lr.py
Normal file
116
timm/scheduler/poly_lr.py
Normal file
@ -0,0 +1,116 @@
|
||||
""" Polynomial Scheduler
|
||||
|
||||
Polynomial LR schedule with warmup, noise.
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from .scheduler import Scheduler
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PolyLRScheduler(Scheduler):
|
||||
""" Polynomial LR Scheduler w/ warmup, noise, and k-decay
|
||||
|
||||
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
power: float = 0.5,
|
||||
lr_min: float = 0.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
k_decay=.5,
|
||||
initialize=True) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
|
||||
_logger.warning("Cosine annealing scheduler will have no effect on the learning "
|
||||
"rate since t_initial = t_mul = eta_mul = 1.")
|
||||
self.t_initial = t_initial
|
||||
self.power = power
|
||||
self.lr_min = lr_min
|
||||
self.cycle_mul = cycle_mul
|
||||
self.cycle_decay = cycle_decay
|
||||
self.cycle_limit = cycle_limit
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
self.k_decay = k_decay
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
else:
|
||||
self.warmup_steps = [1 for _ in self.base_values]
|
||||
|
||||
def _get_lr(self, t):
|
||||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
|
||||
if self.cycle_mul != 1:
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
|
||||
t_i = self.cycle_mul ** i * self.t_initial
|
||||
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
|
||||
else:
|
||||
i = t // self.t_initial
|
||||
t_i = self.t_initial
|
||||
t_curr = t - (self.t_initial * i)
|
||||
|
||||
gamma = self.cycle_decay ** i
|
||||
lr_max_values = [v * gamma for v in self.base_values]
|
||||
k = self.k_decay
|
||||
|
||||
if i < self.cycle_limit:
|
||||
lrs = [
|
||||
self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
|
||||
for lr_max in lr_max_values
|
||||
]
|
||||
else:
|
||||
lrs = [self.lr_min for _ in self.base_values]
|
||||
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
return self.t_initial * cycles
|
||||
else:
|
||||
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
@ -1,11 +1,12 @@
|
||||
""" Scheduler Factory
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
from .cosine_lr import CosineLRScheduler
|
||||
from .tanh_lr import TanhLRScheduler
|
||||
from .step_lr import StepLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
from .multistep_lr import MultiStepLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
from .poly_lr import PolyLRScheduler
|
||||
from .step_lr import StepLRScheduler
|
||||
from .tanh_lr import TanhLRScheduler
|
||||
|
||||
|
||||
def create_scheduler(args, optimizer):
|
||||
@ -27,19 +28,22 @@ def create_scheduler(args, optimizer):
|
||||
noise_std=getattr(args, 'lr_noise_std', 1.),
|
||||
noise_seed=getattr(args, 'seed', 42),
|
||||
)
|
||||
cycle_args = dict(
|
||||
cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
|
||||
cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
|
||||
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
|
||||
)
|
||||
|
||||
lr_scheduler = None
|
||||
if args.sched == 'cosine':
|
||||
lr_scheduler = CosineLRScheduler(
|
||||
optimizer,
|
||||
t_initial=num_epochs,
|
||||
t_mul=getattr(args, 'lr_cycle_mul', 1.),
|
||||
lr_min=args.min_lr,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
|
||||
t_in_epochs=True,
|
||||
k_decay=getattr(args, 'lr_k_decay', 1.0),
|
||||
**cycle_args,
|
||||
**noise_args,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
@ -47,12 +51,11 @@ def create_scheduler(args, optimizer):
|
||||
lr_scheduler = TanhLRScheduler(
|
||||
optimizer,
|
||||
t_initial=num_epochs,
|
||||
t_mul=getattr(args, 'lr_cycle_mul', 1.),
|
||||
lr_min=args.min_lr,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
cycle_limit=getattr(args, 'lr_cycle_limit', 1),
|
||||
t_in_epochs=True,
|
||||
**cycle_args,
|
||||
**noise_args,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
@ -87,5 +90,18 @@ def create_scheduler(args, optimizer):
|
||||
cooldown_t=0,
|
||||
**noise_args,
|
||||
)
|
||||
elif args.sched == 'poly':
|
||||
lr_scheduler = PolyLRScheduler(
|
||||
optimizer,
|
||||
power=args.decay_rate, # overloading 'decay_rate' as polynomial power
|
||||
t_initial=num_epochs,
|
||||
lr_min=args.min_lr,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
k_decay=getattr(args, 'lr_k_decay', 1.0),
|
||||
**cycle_args,
|
||||
**noise_args,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
|
||||
return lr_scheduler, num_epochs
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
TanH schedule with warmup, cycle/restarts, noise.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
@ -24,15 +24,15 @@ class TanhLRScheduler(Scheduler):
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
t_initial: int,
|
||||
lb: float = -6.,
|
||||
ub: float = 4.,
|
||||
t_mul: float = 1.,
|
||||
lb: float = -7.,
|
||||
ub: float = 3.,
|
||||
lr_min: float = 0.,
|
||||
decay_rate: float = 1.,
|
||||
cycle_mul: float = 1.,
|
||||
cycle_decay: float = 1.,
|
||||
cycle_limit: int = 1,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
cycle_limit=0,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
@ -53,9 +53,9 @@ class TanhLRScheduler(Scheduler):
|
||||
self.lb = lb
|
||||
self.ub = ub
|
||||
self.t_initial = t_initial
|
||||
self.t_mul = t_mul
|
||||
self.lr_min = lr_min
|
||||
self.decay_rate = decay_rate
|
||||
self.cycle_mul = cycle_mul
|
||||
self.cycle_decay = cycle_decay
|
||||
self.cycle_limit = cycle_limit
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
@ -75,27 +75,26 @@ class TanhLRScheduler(Scheduler):
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
|
||||
if self.t_mul != 1:
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
|
||||
t_i = self.t_mul ** i * self.t_initial
|
||||
t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
|
||||
if self.cycle_mul != 1:
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
|
||||
t_i = self.cycle_mul ** i * self.t_initial
|
||||
t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
|
||||
else:
|
||||
i = t // self.t_initial
|
||||
t_i = self.t_initial
|
||||
t_curr = t - (self.t_initial * i)
|
||||
|
||||
if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
|
||||
gamma = self.decay_rate ** i
|
||||
lr_min = self.lr_min * gamma
|
||||
if i < self.cycle_limit:
|
||||
gamma = self.cycle_decay ** i
|
||||
lr_max_values = [v * gamma for v in self.base_values]
|
||||
|
||||
tr = t_curr / t_i
|
||||
lrs = [
|
||||
lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
|
||||
self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
|
||||
for lr_max in lr_max_values
|
||||
]
|
||||
else:
|
||||
lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]
|
||||
lrs = [self.lr_min for _ in self.base_values]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
@ -111,10 +110,8 @@ class TanhLRScheduler(Scheduler):
|
||||
return None
|
||||
|
||||
def get_cycle_length(self, cycles=0):
|
||||
if not cycles:
|
||||
cycles = self.cycle_limit
|
||||
cycles = max(1, cycles)
|
||||
if self.t_mul == 1.0:
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
return self.t_initial * cycles
|
||||
else:
|
||||
return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
|
||||
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
|
Loading…
x
Reference in New Issue
Block a user