This commit is contained in:
fzyzcjy 2024-04-02 10:25:30 +08:00
parent 8880a5cd5c
commit b44e4e45a2
7 changed files with 13 additions and 7 deletions

View File

@ -8,6 +8,7 @@ import logging
import math import math
import numpy as np import numpy as np
import torch import torch
from typing import List
from .scheduler import Scheduler from .scheduler import Scheduler
@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler):
else: else:
self.warmup_steps = [1 for _ in self.base_values] self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t): def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t: if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else: else:

View File

@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler):
# assumes self.decay_t is sorted # assumes self.decay_t is sorted
return bisect.bisect_right(self.decay_t, t + 1) return bisect.bisect_right(self.decay_t, t + 1)
def _get_lr(self, t): def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t: if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else: else:

View File

@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import torch import torch
from typing import List
from .scheduler import Scheduler from .scheduler import Scheduler
@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler):
param_group['lr'] = new_lr param_group['lr'] = new_lr
self.restore_lr = restore_lr self.restore_lr = restore_lr
def _get_lr(self, t: int) -> float: def _get_lr(self, t: int) -> List[float]:
assert False, 'should not be called as step is overridden' assert False, 'should not be called as step is overridden'

View File

@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
""" """
import math import math
import logging import logging
from typing import List
import torch import torch
@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler):
else: else:
self.warmup_steps = [1 for _ in self.base_values] self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t): def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t: if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else: else:

View File

@ -1,6 +1,6 @@
import abc import abc
from abc import ABC from abc import ABC
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
import torch import torch

View File

@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman
""" """
import math import math
import torch import torch
from typing import List
from .scheduler import Scheduler from .scheduler import Scheduler
@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler):
else: else:
self.warmup_steps = [1 for _ in self.base_values] self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t): def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t: if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else: else:

View File

@ -8,6 +8,7 @@ import logging
import math import math
import numpy as np import numpy as np
import torch import torch
from typing import List
from .scheduler import Scheduler from .scheduler import Scheduler
@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler):
else: else:
self.warmup_steps = [1 for _ in self.base_values] self.warmup_steps = [1 for _ in self.base_values]
def _get_lr(self, t): def _get_lr(self, t: int) -> List[float]:
if t < self.warmup_t: if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else: else: