mirror of https://github.com/open-mmlab/mmcv.git
commit
1d6e91b1b0
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
||||||
|
from math import cos, pi
|
||||||
|
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
|
|
||||||
|
|
||||||
|
@ -161,3 +163,20 @@ class InvLrUpdaterHook(LrUpdaterHook):
|
||||||
def get_lr(self, runner, base_lr):
|
def get_lr(self, runner, base_lr):
|
||||||
progress = runner.epoch if self.by_epoch else runner.iter
|
progress = runner.epoch if self.by_epoch else runner.iter
|
||||||
return base_lr * (1 + self.gamma * progress)**(-self.power)
|
return base_lr * (1 + self.gamma * progress)**(-self.power)
|
||||||
|
|
||||||
|
|
||||||
|
class CosineLrUpdaterHook(LrUpdaterHook):
|
||||||
|
|
||||||
|
def __init__(self, target_lr=0, **kwargs):
|
||||||
|
self.target_lr = target_lr
|
||||||
|
super(CosineLrUpdaterHook, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
def get_lr(self, runner, base_lr):
|
||||||
|
if self.by_epoch:
|
||||||
|
progress = runner.epoch
|
||||||
|
max_progress = runner.max_epochs
|
||||||
|
else:
|
||||||
|
progress = runner.iter
|
||||||
|
max_progress = runner.max_iters
|
||||||
|
return self.target_lr + 0.5 * (base_lr - self.target_lr) * \
|
||||||
|
(1 + cos(pi * (progress / max_progress)))
|
||||||
|
|
Loading…
Reference in New Issue