diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index e2c975fb..4eaaa86a 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -8,6 +8,7 @@ import logging import math import numpy as np import torch +from typing import List from .scheduler import Scheduler @@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/multistep_lr.py b/timm/scheduler/multistep_lr.py index 10f2fb50..e5db556d 100644 --- a/timm/scheduler/multistep_lr.py +++ b/timm/scheduler/multistep_lr.py @@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler): # assumes self.decay_t is sorted return bisect.bisect_right(self.decay_t, t + 1) - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index 9f827157..e868bd5e 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup. Hacked together by / Copyright 2020 Ross Wightman """ import torch +from typing import List from .scheduler import Scheduler @@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler): param_group['lr'] = new_lr self.restore_lr = restore_lr - def _get_lr(self, t: int) -> float: + def _get_lr(self, t: int) -> List[float]: assert False, 'should not be called as step is overridden' diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py index 906f6acf..8875e15b 100644 --- a/timm/scheduler/poly_lr.py +++ b/timm/scheduler/poly_lr.py @@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman """ import math import logging +from typing import List import torch @@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 3b1913f3..583357f7 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -1,6 +1,6 @@ import abc from abc import ABC -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py index 70a45a70..c205d437 100644 --- a/timm/scheduler/step_lr.py +++ b/timm/scheduler/step_lr.py @@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ import math import torch +from typing import List + from .scheduler import Scheduler @@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index 48acc61b..94455302 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ -8,6 +8,7 @@ import logging import math import numpy as np import torch +from typing import List from .scheduler import Scheduler @@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler): else: self.warmup_steps = [1 for _ in self.base_values] - def _get_lr(self, t): + def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: