diff --git a/ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml b/ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml
index b500fb203..70c70a99b 100644
--- a/ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml
+++ b/ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml
@@ -87,7 +87,7 @@ Optimizer:
   - SGD:
       scope: CenterLoss
       lr:
-        name: Constant
+        name: ConstLR
         learning_rate: 1000.0 # NOTE: set to ori_lr*(1/centerloss_weight) to avoid manually scaling centers' gradients.

 # data loader for train and eval
diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py
index 437ffe243..c8d87517e 100644
--- a/ppcls/optimizer/learning_rate.py
+++ b/ppcls/optimizer/learning_rate.py
@@ -93,7 +93,26 @@ class LRBase(object):
         return warmup_lr


-class Constant(LRBase):
+class Constant(lr.LRScheduler):
+    """Constant learning rate scheduler implementation
+
+    Args:
+        learning_rate (float): The initial learning rate
+        last_epoch (int, optional): The index of the last epoch. Default: -1.
+    """
+
+    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
+        self.learning_rate = learning_rate
+        self.last_epoch = last_epoch
+        super(Constant, self).__init__()
+
+    def get_lr(self) -> float:
+        """Always return the same learningning rate
+        """
+        return self.learning_rate
+
+
+class ConstLR(LRBase):
     """Constant learning rate
 
     Args:
@@ -115,22 +134,14 @@ class Constant(LRBase):
                  last_epoch=-1,
                  by_epoch=False,
                  **kwargs):
-        super(Constant, self).__init__(epochs, step_each_epoch, learning_rate,
-                                       warmup_epoch, warmup_start_lr,
-                                       last_epoch, by_epoch)
+        super(ConstLR, self).__init__(epochs, step_each_epoch, learning_rate,
+                                      warmup_epoch, warmup_start_lr,
+                                      last_epoch, by_epoch)

     def __call__(self):
-        learning_rate = lr.LRScheduler(
+        learning_rate = Constant(
             learning_rate=self.learning_rate, last_epoch=self.last_epoch)

-        def make_get_lr():
-            def get_lr(self):
-                return self.learning_rate
-
-            return get_lr
-
-        setattr(learning_rate, "get_lr", make_get_lr())
-
         if self.warmup_steps > 0:
             learning_rate = self.linear_warmup(learning_rate)
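
For context, here is the pattern the patch adopts, pulled out of the diff into a standalone sketch. The class body mirrors the new `Constant` added above; the driver loop at the bottom (`sched.step()` calls, the `__main__` guard, the 1000.0 value taken from the YAML hunk) is an assumed minimal usage example, not code from this PR. The key point is that `paddle.optimizer.lr.LRScheduler` resolves the current rate through `self.get_lr()` when stepping (in the Paddle versions we have seen, `__init__` itself calls `step()` once), so a proper subclass override is the supported hook, whereas the old `setattr(instance, "get_lr", plain_function)` monkey-patch stored an unbound function on a bare `LRScheduler` instance.

```python
# Standalone sketch of the scheduler this patch adds. The class mirrors the
# diff; the usage below is an assumed minimal driver for illustration only.
from paddle.optimizer import lr


class Constant(lr.LRScheduler):
    """Scheduler that always yields the same learning rate."""

    def __init__(self, learning_rate, last_epoch=-1, **kwargs):
        # Set self.learning_rate BEFORE calling the base __init__: the base
        # class steps once during construction, which invokes get_lr().
        self.learning_rate = learning_rate
        self.last_epoch = last_epoch
        super(Constant, self).__init__()

    def get_lr(self):
        # Ignore epoch/step state entirely: the rate never decays.
        return self.learning_rate


if __name__ == "__main__":
    sched = Constant(learning_rate=1000.0)  # value from the YAML hunk above
    for _ in range(3):
        sched.step()
    print(sched.get_lr())  # prints 1000.0 no matter how many steps ran
```

The split also explains the config hunk: `ConstLR(LRBase)` keeps the PaddleClas-level conventions (warmup, `by_epoch`) and now wraps the low-level `Constant`, matching how the other schedulers in `learning_rate.py` wrap paddle's native ones, so the YAML `name:` for the CenterLoss optimizer must change from `Constant` to `ConstLR`.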