From 8880a5cd5c079244ed0cb4930ec31f811a991c4a Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sat, 23 Mar 2024 11:27:33 +0800 Subject: [PATCH 1/2] Update scheduler.py --- timm/scheduler/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 4ae2e2ae..3b1913f3 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -65,10 +65,10 @@ class Scheduler(ABC): self.__dict__.update(state_dict) @abc.abstractmethod - def _get_lr(self, t: int) -> float: + def _get_lr(self, t: int) -> List[float]: pass - def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]: + def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]: proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs) if not proceed: return None From b44e4e45a24946f529d875eac984992a91e7c3c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 2 Apr 2024 10:25:30 +0800 Subject: [PATCH 2/2] more --- timm/scheduler/cosine_lr.py | 3 ++- timm/scheduler/multistep_lr.py | 2 +- timm/scheduler/plateau_lr.py | 3 ++- timm/scheduler/poly_lr.py | 3 ++- timm/scheduler/scheduler.py | 2 +- timm/scheduler/step_lr.py | 4 +++- timm/scheduler/tanh_lr.py | 3 ++- 7 files changed, 13 insertions(+), 7 deletions(-) diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index e2c975fb..4eaaa86a 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -8,6 +8,7 @@ import logging import math import numpy as np import torch +from typing import List from .scheduler import Scheduler @@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/multistep_lr.py b/timm/scheduler/multistep_lr.py index 10f2fb50..e5db556d 100644 --- a/timm/scheduler/multistep_lr.py +++ b/timm/scheduler/multistep_lr.py @@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler): # assumes self.decay_t is sorted return bisect.bisect_right(self.decay_t, t + 1) - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index 9f827157..e868bd5e 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup. Hacked together by / Copyright 2020 Ross Wightman """ import torch +from typing import List from .scheduler import Scheduler @@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler): param_group['lr'] = new_lr self.restore_lr = restore_lr - def _get_lr(self, t: int) -> float: + def _get_lr(self, t: int) -> List[float]: assert False, 'should not be called as step is overridden' diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py index 906f6acf..8875e15b 100644 --- a/timm/scheduler/poly_lr.py +++ b/timm/scheduler/poly_lr.py @@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ import math import logging +from typing import List import torch @@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 3b1913f3..583357f7 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -1,6 +1,6 @@ import abc from abc import ABC -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py index 70a45a70..c205d437 100644 --- a/timm/scheduler/step_lr.py +++ b/timm/scheduler/step_lr.py @@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ import math import torch +from typing import List + from .scheduler import Scheduler @@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index 48acc61b..94455302 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ -8,6 +8,7 @@ import logging import math import numpy as np import torch +from typing import List from .scheduler import Scheduler @@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: