mirror of https://github.com/open-mmlab/mmcv.git
add cosine lr schedule
parent
c8c3493868
commit
ccb7ca30f8
|
@ -1,5 +1,7 @@
|
|||
from __future__ import division
|
||||
|
||||
from math import cos, pi
|
||||
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
|
@ -161,3 +163,20 @@ class InvLrUpdaterHook(LrUpdaterHook):
|
|||
def get_lr(self, runner, base_lr):
|
||||
progress = runner.epoch if self.by_epoch else runner.iter
|
||||
return base_lr * (1 + self.gamma * progress)**(-self.power)
|
||||
|
||||
|
||||
class CosineLrUpdaterHook(LrUpdaterHook):
|
||||
|
||||
def __init__(self, target_lr=0, **kwargs):
|
||||
self.target_lr = target_lr
|
||||
super(CosineLrUpdaterHook, self).__init__(**kwargs)
|
||||
|
||||
def get_lr(self, runner, base_lr):
|
||||
if self.by_epoch:
|
||||
progress = runner.epoch
|
||||
max_progress = runner.max_epochs
|
||||
else:
|
||||
progress = runner.iter
|
||||
max_progress = runner.max_iters
|
||||
return self.target_lr + 0.5 * (base_lr - self.target_lr) * \
|
||||
(1 + cos(pi * (progress / max_progress)))
|
||||
|
|
Loading…
Reference in New Issue