diff --git a/mmcv/runner/hooks/lr_updater.py b/mmcv/runner/hooks/lr_updater.py index 7c4e37292..837f74c30 100644 --- a/mmcv/runner/hooks/lr_updater.py +++ b/mmcv/runner/hooks/lr_updater.py @@ -138,8 +138,9 @@ class ExpLrUpdaterHook(LrUpdaterHook): class PolyLrUpdaterHook(LrUpdaterHook): - def __init__(self, power=1., **kwargs): + def __init__(self, power=1., min_lr=0., **kwargs): self.power = power + self.min_lr = min_lr super(PolyLrUpdaterHook, self).__init__(**kwargs) def get_lr(self, runner, base_lr): @@ -149,7 +150,8 @@ class PolyLrUpdaterHook(LrUpdaterHook): else: progress = runner.iter max_progress = runner.max_iters - return base_lr * (1 - progress / max_progress)**self.power + coeff = (1 - progress / max_progress)**self.power + return (base_lr - self.min_lr) * coeff + self.min_lr class InvLrUpdaterHook(LrUpdaterHook):