mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
update learning_rate.py
This commit is contained in:
parent
2cfd8dd828
commit
30cbb18321
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from __future__ import (absolute_import, division, print_function,
|
from __future__ import (absolute_import, division, print_function,
|
||||||
unicode_literals)
|
unicode_literals)
|
||||||
|
import types
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -466,5 +466,12 @@ class ReduceOnPlateau(LRBase):
|
|||||||
if self.warmup_steps > 0:
|
if self.warmup_steps > 0:
|
||||||
learning_rate = self.linear_warmup(learning_rate)
|
learning_rate = self.linear_warmup(learning_rate)
|
||||||
|
|
||||||
|
# NOTE: Implement get_lr() method for class `ReduceOnPlateau`,
|
||||||
|
# which is called in `log_info` function
|
||||||
|
def get_lr(self):
|
||||||
|
return self.last_lr
|
||||||
|
|
||||||
|
learning_rate.get_lr = types.MethodType(get_lr, learning_rate)
|
||||||
|
|
||||||
setattr(learning_rate, "by_epoch", self.by_epoch)
|
setattr(learning_rate, "by_epoch", self.by_epoch)
|
||||||
return learning_rate
|
return learning_rate
|
||||||
|
Loading…
x
Reference in New Issue
Block a user