commit 59b3d86c1d

timm/scheduler/cosine_lr.py

@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List

 from .scheduler import Scheduler

@@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
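The warmup branch that recurs in these hunks is a linear ramp. Elsewhere in each scheduler's __init__ (outside this diff), warmup_steps holds one per-parameter-group increment, computed as (base_lr - warmup_lr_init) / warmup_t, so warmup_lr_init + t * s walks from the warmup floor up to each group's base LR. A minimal sketch of that arithmetic, with invented values:

    # Sketch only: warmup_lr_init, warmup_t and base_values are invented here;
    # the warmup_steps formula mirrors the schedulers' __init__ (not in this diff).
    warmup_lr_init = 1e-5
    warmup_t = 5
    base_values = [0.1, 0.05]  # one base LR per optimizer param group

    warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]

    for t in range(warmup_t):
        # the exact expression from the diff: a linear ramp per group
        lrs = [warmup_lr_init + t * s for s in warmup_steps]
        print(t, lrs)

From t == warmup_t onward the else branch takes over with the scheduler-specific curve (cosine here).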
timm/scheduler/multistep_lr.py

@@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler):
         # assumes self.decay_t is sorted
         return bisect.bisect_right(self.decay_t, t + 1)

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
timm/scheduler/plateau_lr.py

@@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from typing import List

 from .scheduler import Scheduler

@@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler):
             param_group['lr'] = new_lr
         self.restore_lr = restore_lr

-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         assert False, 'should not be called as step is overridden'
timm/scheduler/poly_lr.py

@@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
 """
 import math
 import logging
+from typing import List

 import torch

@@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
timm/scheduler/scheduler.py

@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import torch

@@ -65,10 +65,10 @@ class Scheduler(ABC):
         self.__dict__.update(state_dict)

     @abc.abstractmethod
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         pass

-    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
+    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
         proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
         if not proceed:
             return None
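With the abstract method now annotated, the contract is explicit: every subclass's _get_lr returns one learning rate per optimizer param group. A minimal conforming subclass, as a sketch (the constant-LR behaviour is invented for illustration; it assumes placement inside timm/scheduler/, where Scheduler.__init__ populates self.base_values):

    from typing import List

    from .scheduler import Scheduler


    class ConstantLRScheduler(Scheduler):
        """Illustrative only: returns the base LRs unchanged at every step."""

        def _get_lr(self, t: int) -> List[float]:
            # one float per optimizer param group, per the new annotation
            return list(self.base_values)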
timm/scheduler/step_lr.py

@@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import torch
+from typing import List
+

 from .scheduler import Scheduler

@@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
timm/scheduler/tanh_lr.py

@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List

 from .scheduler import Scheduler

@@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
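A quick way to observe the annotated return type, as a hypothetical smoke test (not part of the commit; it assumes timm's StepLRScheduler constructor arguments decay_t, warmup_t and warmup_lr_init, and pokes the protected _get_lr directly):

    import torch
    from timm.scheduler.step_lr import StepLRScheduler

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = StepLRScheduler(optimizer, decay_t=30, warmup_t=3, warmup_lr_init=1e-4)

    lrs = scheduler._get_lr(0)  # now typed as List[float]
    assert len(lrs) == len(optimizer.param_groups)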