commit 59b3d86c1d
@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
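
The -/+ pair above is the substance of the commit: _get_lr returns one LR per optimizer param group, so the annotation becomes List[float]. The visible branch is plain linear warmup; below is a standalone sketch of that arithmetic. The warmup_steps precomputation is inferred from the surrounding class and is not shown in this hunk.

from typing import List


def linear_warmup_lrs(
        t: int,
        base_values: List[float],
        warmup_t: int,
        warmup_lr_init: float,
) -> List[float]:
    # Presumed precomputation (done once in __init__ in the real class):
    # per-group increment that reaches each base LR exactly at t == warmup_t.
    warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]
    # The line visible in the hunk:
    return [warmup_lr_init + t * s for s in warmup_steps]


# e.g. base LR 0.1, warmup_t=5, warmup_lr_init=1e-4:
# t=0 -> 0.0001, t=5 -> 0.1 exactly, rising linearly in between.

Past warmup, the elided else branch applies cosine annealing, roughly lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / t_initial)) per cycle.
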
@@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler):
         # assumes self.decay_t is sorted
         return bisect.bisect_right(self.decay_t, t + 1)
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
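
bisect.bisect_right(self.decay_t, t + 1) turns the sorted milestone list into a count of milestones reached at step t; with the t + 1, a milestone m starts counting at t == m - 1. A hedged sketch of how that count would drive the decayed LRs follows; only the bisect line and the warmup branch appear in the hunks, so the decay_rate ** n form is an assumption.

import bisect
from typing import List


def multistep_lrs(
        t: int,
        base_values: List[float],
        decay_t: List[int],
        decay_rate: float = 0.1,
) -> List[float]:
    # The line from the hunk: count of sorted milestones reached at step t.
    n = bisect.bisect_right(decay_t, t + 1)
    return [v * decay_rate ** n for v in base_values]


# multistep_lrs(28, [0.1], [30, 60]) -> [0.1]
# multistep_lrs(29, [0.1], [30, 60]) -> ~[0.01]
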
@@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler):
                 param_group['lr'] = new_lr
         self.restore_lr = restore_lr
 
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         assert False, 'should not be called as step is overridden'
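
PlateauLRScheduler wraps torch.optim.lr_scheduler.ReduceLROnPlateau and overrides step(), so its _get_lr is unreachable and simply asserts; the annotation changes here only to stay consistent with the abstract signature in scheduler.py. A minimal sketch of that delegation shape, illustrative rather than the timm implementation:

from typing import List

import torch


class PlateauDelegationSketch:
    """Minimal sketch of the delegation shape, not the timm implementation."""

    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        # All LR changes are made by torch's plateau scheduler.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

    def step(self, epoch: int, metric: float) -> None:
        # step() is overridden to feed the metric straight through...
        self.lr_scheduler.step(metric)

    def _get_lr(self, t: int) -> List[float]:
        # ...so the base class hook is unreachable, hence the assert.
        assert False, 'should not be called as step is overridden'
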
@@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
 """
 import math
 import logging
+from typing import List
 
 import torch
 
@@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
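
For the polynomial scheduler, the elided else branch decays from the base LR toward a floor. A sketch of the usual polynomial decay, with illustrative parameter defaults, since the hunk itself only shows the shared warmup branch:

from typing import List


def poly_lrs(
        t: int,
        base_values: List[float],
        t_initial: int,
        power: float = 0.5,
        lr_min: float = 0.0,
) -> List[float]:
    # Polynomial decay from each base LR down to lr_min over t_initial steps.
    frac = min(t / t_initial, 1.0)
    return [lr_min + (v - lr_min) * (1.0 - frac) ** power for v in base_values]
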
@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 
@@ -65,10 +65,10 @@ class Scheduler(ABC):
         self.__dict__.update(state_dict)
 
     @abc.abstractmethod
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         pass
 
-    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
+    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
         proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
         if not proceed:
             return None
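
This file is the root of the change: the abstract _get_lr now advertises one learning rate per param group, and _get_values propagates the Optional[List[float]]. The proceed gate reduces to on_epoch == self.t_in_epochs: a scheduler configured in epoch units answers only epoch-time queries, one configured in update units only update-time queries. A minimal subclass against the new signature; the constant-factor scheduler is hypothetical, and the import path assumes the usual timm layout where this base class lives in timm/scheduler/scheduler.py:

from typing import List

import torch

from timm.scheduler.scheduler import Scheduler


class ConstantFactorScheduler(Scheduler):
    """Hypothetical subclass: scales every param group's base LR by a factor."""

    def __init__(self, optimizer: torch.optim.Optimizer, factor: float = 0.5) -> None:
        # param_group_field='lr' means base_values holds each group's initial
        # LR, as the warmup hunks above suggest.
        super().__init__(optimizer, param_group_field='lr')
        self.factor = factor

    def _get_lr(self, t: int) -> List[float]:
        # One value per param group: the new List[float] contract.
        return [v * self.factor for v in self.base_values]
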
@@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import torch
+from typing import List
+
 
 from .scheduler import Scheduler
 
@@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
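
StepLRScheduler drops the LR by a fixed factor at a fixed interval. A sketch of the decay the elided else branch would compute; the floor-division form and defaults are assumptions, since the hunk only shows the warmup branch:

from typing import List


def step_lrs(
        t: int,
        base_values: List[float],
        decay_t: int,
        decay_rate: float = 0.1,
) -> List[float]:
    # Fixed-interval decay: the LR drops by decay_rate every decay_t steps.
    return [v * decay_rate ** (t // decay_t) for v in base_values]


# step_lrs(29, [0.1], 30) -> [0.1]; step_lrs(30, [0.1], 30) -> ~[0.01]
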
@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
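
TanhLRScheduler anneals along a hyperbolic-tangent curve instead of a cosine. A hedged sketch of that shape; the lb/ub bounds are illustrative defaults, since the hunk again only shows the warmup branch:

import math
from typing import List


def tanh_lrs(
        t: int,
        base_values: List[float],
        t_initial: int,
        lb: float = -7.0,
        ub: float = 3.0,
        lr_min: float = 0.0,
) -> List[float]:
    # Anneal from the base LR to lr_min along a tanh curve: the tanh argument
    # sweeps linearly from lb to ub as t goes from 0 to t_initial, so the
    # multiplier falls smoothly from ~1 to ~0.
    tr = min(t / t_initial, 1.0)
    return [
        lr_min + 0.5 * (v - lr_min) * (1.0 - math.tanh(lb + (ub - lb) * tr))
        for v in base_values
    ]
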