# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import math
import types
from abc import abstractmethod
from typing import Union

from paddle.optimizer import lr

from ppcls.utils import logger


class LRBase(object):
    """Base class for custom learning rates.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
        verbose (bool): If True, prints a message to stdout for each update. Defaults to False.
    """

    def __init__(self,
                 epochs: int,
                 step_each_epoch: int,
                 learning_rate: float,
                 warmup_epoch: int,
                 warmup_start_lr: float,
                 last_epoch: int,
                 by_epoch: bool,
                 verbose: bool=False) -> None:
        """Initialize and record the necessary parameters."""
        super(LRBase, self).__init__()
        if warmup_epoch >= epochs:
            msg = f"When using warm up, the value of \"Global.epochs\" must be greater than the value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
            logger.warning(msg)
            warmup_epoch = epochs
        self.epochs = epochs
        self.step_each_epoch = step_each_epoch
        self.learning_rate = learning_rate
        self.warmup_epoch = warmup_epoch
        self.warmup_steps = self.warmup_epoch if by_epoch else round(
            self.warmup_epoch * self.step_each_epoch)
        self.warmup_start_lr = warmup_start_lr
        self.last_epoch = last_epoch
        self.by_epoch = by_epoch
        self.verbose = verbose

    @abstractmethod
    def __call__(self, *kargs, **kwargs) -> lr.LRScheduler:
        """Generate a learning rate scheduler.

        Returns:
            lr.LRScheduler: learning rate scheduler
        """
        pass

    def linear_warmup(
            self,
            learning_rate: Union[float, lr.LRScheduler]) -> lr.LinearWarmup:
        """Add a linear warmup before learning_rate.

        Args:
            learning_rate (Union[float, lr.LRScheduler]): original learning rate without warmup

        Returns:
            lr.LinearWarmup: learning rate scheduler with warmup
        """
        warmup_lr = lr.LinearWarmup(
            learning_rate=learning_rate,
            warmup_steps=self.warmup_steps,
            start_lr=self.warmup_start_lr,
            end_lr=self.learning_rate,
            last_epoch=self.last_epoch,
            verbose=self.verbose)
        return warmup_lr


class Constant(lr.LRScheduler):
    """Constant learning rate class implementation.

    Args:
        learning_rate (float): The initial learning rate
        last_epoch (int, optional): The index of last epoch. Default: -1.
    """

    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch
        super(Constant, self).__init__()

    def get_lr(self) -> float:
        """Always return the same learning rate."""
        return self.learning_rate


class ConstLR(LRBase):
    """Constant learning rate.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(ConstLR, self).__init__(epochs, step_each_epoch, learning_rate,
                                      warmup_epoch, warmup_start_lr,
                                      last_epoch, by_epoch)

    def __call__(self):
        learning_rate = Constant(
            learning_rate=self.learning_rate, last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class Linear(LRBase):
    """Linear learning rate decay.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        end_lr (float, optional): The minimum final learning rate. Defaults to 0.0.
        power (float, optional): Power of polynomial. Defaults to 1.0.
        cycle (bool, optional): Whether to cycle the decay, passed through to ``lr.PolynomialDecay``. Defaults to False.
        warmup_epoch (int): number of warmup epoch(s)
        warmup_start_lr (float): start learning rate within warmup
        last_epoch (int): last epoch
        by_epoch (bool): learning rate decays by epoch when by_epoch is True, else by iter
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 end_lr=0.0,
                 power=1.0,
                 cycle=False,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Linear, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.decay_steps = (epochs - self.warmup_epoch) * step_each_epoch
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        self.warmup_steps = round(self.warmup_epoch * step_each_epoch)
        if self.by_epoch:
            self.decay_steps = self.epochs - self.warmup_epoch

    def __call__(self):
        learning_rate = lr.PolynomialDecay(
            learning_rate=self.learning_rate,
            decay_steps=self.decay_steps,
            end_lr=self.end_lr,
            power=self.power,
            cycle=self.cycle,
            last_epoch=self.last_epoch) if self.decay_steps > 0 else Constant(
                self.learning_rate)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class Cosine(LRBase):
    """Cosine learning rate decay.

    ``lr = 0.05 * (math.cos(epoch * (math.pi / epochs)) + 1)``

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        eta_min (float, optional): Minimum learning rate. Defaults to 0.0.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 eta_min=0.0,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Cosine, self).__init__(epochs, step_each_epoch, learning_rate,
                                     warmup_epoch, warmup_start_lr, last_epoch,
                                     by_epoch)
        self.T_max = (self.epochs - self.warmup_epoch) * self.step_each_epoch
        self.eta_min = eta_min
        if self.by_epoch:
            self.T_max = self.epochs - self.warmup_epoch

    def __call__(self):
        learning_rate = lr.CosineAnnealingDecay(
            learning_rate=self.learning_rate,
            T_max=self.T_max,
            eta_min=self.eta_min,
            last_epoch=self.last_epoch) if self.T_max > 0 else Constant(
                self.learning_rate)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class Cyclic(LRBase):
    """Cyclic learning rate decay.

    Args:
        epochs (int): Total epoch(s).
        step_each_epoch (int): Number of iterations within an epoch.
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
            setting base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
            Since there is some scaling operation during the process of learning rate adjustment,
            max_learning_rate may not actually be reached.
        warmup_epoch (int): Number of warmup epoch(s).
        warmup_start_lr (float): Start learning rate within warmup.
        step_size_up (int): Number of training steps, which is used to increase the learning rate in a cycle.
            The step size of one cycle will be defined by step_size_up + step_size_down. According to the paper, the step
            size should be set to at least 3 or 4 times the number of steps in one epoch.
        step_size_down (int, optional): Number of training steps, which is used to decrease the learning rate in a cycle.
            If not specified, its value will be initialized to ``step_size_up``. Default: None.
        mode (str, optional): One of 'triangular', 'triangular2' or 'exp_range'.
            If scale_fn is specified, this argument will be ignored. Default: 'triangular'.
        exp_gamma (float): Constant in the 'exp_range' scaling function: exp_gamma**iterations. Used only when mode='exp_range'. Default: 1.0.
        scale_fn (function, optional): A custom scaling function, which is used to replace the three built-in methods.
            It should only have one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
            If specified, then 'mode' will be ignored. Default: None.
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on cycle
            number or cycle iterations (total iterations since start of training). Default: 'cycle'.
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        by_epoch (bool): Learning rate decays by epoch when by_epoch is True, else by iter.
        verbose (bool, optional): If True, prints a message to stdout for each update. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 base_learning_rate,
                 max_learning_rate,
                 warmup_epoch,
                 warmup_start_lr,
                 step_size_up,
                 step_size_down=None,
                 mode='triangular',
                 exp_gamma=1.0,
                 scale_fn=None,
                 scale_mode='cycle',
                 by_epoch=False,
                 last_epoch=-1,
                 verbose=False):
        super(Cyclic, self).__init__(
            epochs, step_each_epoch, base_learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch, verbose)
        self.base_learning_rate = base_learning_rate
        self.max_learning_rate = max_learning_rate
        self.step_size_up = step_size_up
        self.step_size_down = step_size_down
        self.mode = mode
        self.exp_gamma = exp_gamma
        self.scale_fn = scale_fn
        self.scale_mode = scale_mode

    def __call__(self):
        learning_rate = lr.CyclicLR(
            base_learning_rate=self.base_learning_rate,
            max_learning_rate=self.max_learning_rate,
            step_size_up=self.step_size_up,
            step_size_down=self.step_size_down,
            mode=self.mode,
            exp_gamma=self.exp_gamma,
            scale_fn=self.scale_fn,
            scale_mode=self.scale_mode,
            last_epoch=self.last_epoch,
            verbose=self.verbose)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class Step(LRBase):
    """Step learning rate decay.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        step_size (int|float): the interval, in epochs, at which the learning rate is updated.
        gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 step_size,
                 gamma,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
                                   warmup_epoch, warmup_start_lr, last_epoch,
                                   by_epoch)
        self.step_size = int(step_size * step_each_epoch)
        self.gamma = gamma
        if self.by_epoch:
            self.step_size = step_size

    def __call__(self):
        learning_rate = lr.StepDecay(
            learning_rate=self.learning_rate,
            step_size=self.step_size,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class Piecewise(LRBase):
    """Piecewise learning rate decay.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        decay_epochs (List[int]): A list of epoch numbers that mark the decay boundaries. The type of element in the list is python int.
        values (List[float]): A list of learning rate values that will be picked during different epoch boundaries.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 decay_epochs,
                 values,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(Piecewise,
              self).__init__(epochs, step_each_epoch, values[0], warmup_epoch,
                             warmup_start_lr, last_epoch, by_epoch)
        self.values = values
        self.boundaries_steps = [e * step_each_epoch for e in decay_epochs]
        if self.by_epoch is True:
            self.boundaries_steps = decay_epochs

    def __call__(self):
        learning_rate = lr.PiecewiseDecay(
            boundaries=self.boundaries_steps,
            values=self.values,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class MultiStepDecay(LRBase):
    """MultiStepDecay learning rate decay.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        milestones (List[int]): List of epoch boundaries. Must be increasing.
        gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Defaults to 0.1.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(MultiStepDecay, self).__init__(
            epochs, step_each_epoch, learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch)
        self.milestones = [x * step_each_epoch for x in milestones]
        self.gamma = gamma
        if self.by_epoch:
            self.milestones = milestones

    def __call__(self):
        learning_rate = lr.MultiStepDecay(
            learning_rate=self.learning_rate,
            milestones=self.milestones,
            gamma=self.gamma,
            last_epoch=self.last_epoch)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate


class ReduceOnPlateau(LRBase):
    """ReduceOnPlateau learning rate decay.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'``, which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'``, the learning
            rate will reduce when ``loss`` stops ascending. Defaults to ``'min'``.
        factor (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * factor``.
            It should be less than 1.0. Defaults to 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, the learning rate will be reduced.
            Defaults to 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss``,
            so that tiny changes of ``loss`` will be ignored. Defaults to 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``last_loss * threshold``, where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold``. Defaults to ``'rel'``.
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Defaults to 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Defaults to 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
            the update is ignored. Defaults to 1e-8.
        warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
        warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 warmup_epoch=0,
                 warmup_start_lr=0.0,
                 last_epoch=-1,
                 by_epoch=False,
                 **kwargs):
        super(ReduceOnPlateau, self).__init__(
            epochs, step_each_epoch, learning_rate, warmup_epoch,
            warmup_start_lr, last_epoch, by_epoch)
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

    def __call__(self):
        learning_rate = lr.ReduceOnPlateau(
            learning_rate=self.learning_rate,
            mode=self.mode,
            factor=self.factor,
            patience=self.patience,
            threshold=self.threshold,
            threshold_mode=self.threshold_mode,
            cooldown=self.cooldown,
            min_lr=self.min_lr,
            epsilon=self.epsilon)

        if self.warmup_steps > 0:
            learning_rate = self.linear_warmup(learning_rate)

        # NOTE: Implement get_lr() method for class `ReduceOnPlateau`,
        # which is called in `log_info` function
        def get_lr(self):
            return self.last_lr

        learning_rate.get_lr = types.MethodType(get_lr, learning_rate)

        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
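
# Usage note (illustrative sketch, not part of the original logic): the scheduler
# built by ``ReduceOnPlateau.__call__`` reduces the learning rate based on a
# monitored metric rather than on a fixed schedule, so the training loop is
# expected to pass that metric when stepping it, e.g.::
#
#     sched = ReduceOnPlateau(epochs=100, step_each_epoch=500, learning_rate=0.1)()
#     sched.step(val_loss)  # ``val_loss`` is a hypothetical validation metric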


class CosineFixmatch(LRBase):
    """Cosine decay in FixMatch style.

    Args:
        epochs (int): total epoch(s)
        step_each_epoch (int): number of iterations within an epoch
        learning_rate (float): learning rate
        num_warmup_steps (int): the number of warmup steps.
        num_cycles (float, optional): the factor for cosine in the FixMatch learning rate. Defaults to 7 / 16.
        last_epoch (int, optional): last epoch. Defaults to -1.
        by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
    """

    def __init__(self,
                 epochs,
                 step_each_epoch,
                 learning_rate,
                 num_warmup_steps,
                 num_cycles=7 / 16,
                 last_epoch=-1,
                 by_epoch=False):
        self.epochs = epochs
        self.step_each_epoch = step_each_epoch
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_cycles = num_cycles
        self.last_epoch = last_epoch
        self.by_epoch = by_epoch

    def __call__(self):
        def _lr_lambda(current_step):
            if current_step < self.num_warmup_steps:
                return float(current_step) / float(
                    max(1, self.num_warmup_steps))
            no_progress = float(current_step - self.num_warmup_steps) / \
                float(max(1, self.epochs * self.step_each_epoch - self.num_warmup_steps))
            return max(0., math.cos(math.pi * self.num_cycles * no_progress))

        learning_rate = lr.LambdaDecay(
            learning_rate=self.learning_rate,
            lr_lambda=_lr_lambda,
            last_epoch=self.last_epoch)
        setattr(learning_rate, "by_epoch", self.by_epoch)
        return learning_rate
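

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes Paddle and the ppcls package are importable, builds one of the
# schedulers defined above, and steps it per iteration the way a training loop
# would. The numbers below are placeholder values, not recommendations.
if __name__ == "__main__":
    lr_builder = Cosine(
        epochs=10,
        step_each_epoch=100,
        learning_rate=0.1,
        warmup_epoch=1,
        warmup_start_lr=0.0)
    scheduler = lr_builder()  # a paddle.optimizer.lr.LRScheduler instance
    for _ in range(5):
        print(scheduler.get_lr())
        scheduler.step()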