diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py index 27709bb17..987d36b24 100644 --- a/mmcv/runner/hooks/lr_updater.py +++ b/mmcv/runner/hooks/lr_updater.py @@ -1,5 +1,7 @@ from __future__ import division +from math import cos, pi + from .hook import Hook @@ -161,3 +163,20 @@ class InvLrUpdaterHook(LrUpdaterHook): def get_lr(self, runner, base_lr): progress = runner.epoch if self.by_epoch else runner.iter return base_lr * (1 + self.gamma * progress)**(-self.power) + + +class CosineLrUpdaterHook(LrUpdaterHook): + + def __init__(self, target_lr=0, **kwargs): + self.target_lr = target_lr + super(CosineLrUpdaterHook, self).__init__(**kwargs) + + def get_lr(self, runner, base_lr): + if self.by_epoch: + progress = runner.epoch + max_progress = runner.max_epochs + else: + progress = runner.iter + max_progress = runner.max_iters + return self.target_lr + 0.5 * (base_lr - self.target_lr) * \ + (1 + cos(pi * (progress / max_progress)))