update learning_rate.py

This commit is contained in:
HydrogenSulfate 2022-10-17 15:51:48 +08:00
parent 2cfd8dd828
commit 30cbb18321

View File

@ -14,7 +14,7 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import types
from abc import abstractmethod
from typing import Union
@ -466,5 +466,12 @@ class ReduceOnPlateau(LRBase):
if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate)
# NOTE: Implement get_lr() method for class `ReduceOnPlateau`,
# which is called in `log_info` function
def get_lr(self):
return self.last_lr
learning_rate.get_lr = types.MethodType(get_lr, learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate