Add cyclic learning rate
parent 3c21282dfb
commit 8d4a79e57f
@@ -253,6 +253,87 @@ class Cosine(LRBase):
         return learning_rate
+
+
+class Cyclic(LRBase):
+    """Cyclic learning rate decay
+
+    Args:
+        epochs (int): total epoch(s)
+        step_each_epoch (int): number of iterations within an epoch
+        base_learning_rate (float): initial learning rate, which is the lower boundary in the cycle. The paper
+            recommends setting base_learning_rate to 1/3 or 1/4 of max_learning_rate.
+        max_learning_rate (float): maximum learning rate in the cycle, which defines the cycle amplitude.
+            Since some scaling is applied while the learning rate is adjusted, max_learning_rate may not
+            actually be reached.
+        warmup_epoch (int): number of warmup epoch(s)
+        warmup_start_lr (float): start learning rate within warmup
+        step_size_up (int): number of training steps used to increase the learning rate in a cycle.
+            The size of one cycle is step_size_up + step_size_down. According to the paper, the step size
+            should be at least 3 or 4 times the number of steps in one epoch.
+        step_size_down (int, optional): number of training steps used to decrease the learning rate in a cycle.
+            If not specified, its value defaults to step_size_up. Default: None
+        mode (str, optional): one of 'triangular', 'triangular2' or 'exp_range'.
+            Ignored when scale_fn is specified. Default: 'triangular'
+        exp_gamma (float): constant in the 'exp_range' scaling function: exp_gamma**iterations.
+            Used only when mode = 'exp_range'. Default: 1.0
+        scale_fn (function, optional): a custom scaling function that replaces the three built-in modes.
+            It should take a single argument and satisfy 0 <= scale_fn(x) <= 1 for all x >= 0.
+            If specified, 'mode' is ignored. Default: None
+        scale_mode (str, optional): one of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on the
+            cycle number or on cycle iterations (total iterations since the start of training). Default: 'cycle'
+        last_epoch (int, optional): the index of the last epoch. Can be set to restart training.
+            Default: -1, meaning the initial learning rate.
+        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
+        verbose (bool, optional): if True, prints a message to stdout for each update. Defaults to False
+    """
+
+    def __init__(self,
+                 epochs,
+                 step_each_epoch,
+                 base_learning_rate,
+                 max_learning_rate,
+                 warmup_epoch,
+                 warmup_start_lr,
+                 step_size_up,
+                 step_size_down=None,
+                 mode='triangular',
+                 exp_gamma=1.0,
+                 scale_fn=None,
+                 scale_mode='cycle',
+                 by_epoch=False,
+                 last_epoch=-1,
+                 verbose=False):
+        super(Cyclic, self).__init__(
+            epochs, step_each_epoch, base_learning_rate, warmup_epoch,
+            warmup_start_lr, last_epoch, by_epoch, verbose)
+        self.base_learning_rate = base_learning_rate
+        self.max_learning_rate = max_learning_rate
+        self.step_size_up = step_size_up
+        self.step_size_down = step_size_down
+        self.mode = mode
+        self.exp_gamma = exp_gamma
+        self.scale_fn = scale_fn
+        self.scale_mode = scale_mode
+
+    def __call__(self):
+        learning_rate = lr.CyclicLR(
+            base_learning_rate=self.base_learning_rate,
+            max_learning_rate=self.max_learning_rate,
+            step_size_up=self.step_size_up,
+            step_size_down=self.step_size_down,
+            mode=self.mode,
+            exp_gamma=self.exp_gamma,
+            scale_fn=self.scale_fn,
+            scale_mode=self.scale_mode,
+            last_epoch=self.last_epoch,
+            verbose=self.verbose)
+
+        if self.warmup_steps > 0:
+            learning_rate = self.linear_warmup(learning_rate)
+
+        setattr(learning_rate, "by_epoch", self.by_epoch)
+        return learning_rate
+
+
 class Step(LRBase):
     """Step learning rate decay
 
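For context, a minimal usage sketch of the new wrapper. The config values and the one-layer model are hypothetical, and it assumes the Cyclic class above is importable; calling the instance builds the underlying paddle.optimizer.lr.CyclicLR that __call__ delegates to.

import paddle

# Hypothetical values: base lr at ~1/4 of max lr, and step_size_up at
# ~4x the steps in one epoch, as the docstring above recommends.
scheduler = Cyclic(
    epochs=100,
    step_each_epoch=500,
    base_learning_rate=0.025,
    max_learning_rate=0.1,
    warmup_epoch=0,
    warmup_start_lr=0.0,
    step_size_up=2000,
    mode='triangular',
    by_epoch=False)()  # __call__ returns the paddle LRScheduler

model = paddle.nn.Linear(10, 2)
optimizer = paddle.optimizer.Momentum(
    learning_rate=scheduler, momentum=0.9, parameters=model.parameters())

# With by_epoch=False, the training loop is expected to call
# scheduler.step() once per iteration rather than once per epoch.
scheduler.step()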
@@ -421,6 +502,7 @@ class ReduceOnPlateau(LRBase):
         last_epoch (int, optional): last epoch. Defaults to -1.
+        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
     """
 
     def __init__(self,
                  epochs,
                  step_each_epoch,
@@ -488,6 +570,7 @@ class CosineFixmatch(LRBase):
         last_epoch (int, optional): last epoch. Defaults to -1.
+        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
     """
 
     def __init__(self,
                  epochs,
                  step_each_epoch,
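For reference, a standalone sketch of the 'triangular' policy the Cyclic docstring describes. This mirrors the formulation in Smith's cyclical-learning-rates paper; it is not the Paddle implementation itself, which may differ in details such as warmup handling.

import math

def triangular_lr(iteration, base_lr, max_lr, step_size):
    # One cycle spans 2 * step_size iterations: the lr ramps linearly
    # from base_lr up to max_lr, then back down to base_lr.
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

# 'triangular2' additionally halves the amplitude each cycle, and
# 'exp_range' scales it by exp_gamma ** iteration instead.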