commit b44e4e45a2
parent 8880a5cd5c
Author: fzyzcjy
Date:   2024-04-02 10:25:30 +08:00

7 changed files with 13 additions and 7 deletions

timm/scheduler/cosine_lr.py

@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 from .scheduler import Scheduler
@@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
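
The warmup branch in this hunk (and the identical one in every scheduler below) linearly ramps one LR per optimizer param group, which is exactly the shape the new List[float] annotation documents. A minimal standalone sketch with made-up numbers; the per-step increment setup mirrors the (v - warmup_lr_init) / warmup_t pattern these schedulers share in __init__, which the hunk truncates before, so treat that detail as an assumption:

from typing import List

base_values: List[float] = [0.1, 0.05]  # hypothetical per-param-group target LRs
warmup_t = 5                            # hypothetical warmup length in steps
warmup_lr_init = 1e-4                   # hypothetical starting LR

# Per-step increments: reach each base value after warmup_t steps
# (assumed __init__ logic, not shown in the hunk above).
warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]

def get_warmup_lr(t: int) -> List[float]:
    # Same shape as the new annotation: one LR per optimizer param group.
    return [warmup_lr_init + t * s for s in warmup_steps]

print(get_warmup_lr(0))  # [0.0001, 0.0001]
print(get_warmup_lr(5))  # ~[0.1, 0.05], the base values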

timm/scheduler/multistep_lr.py

@@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler):
         # assumes self.decay_t is sorted
         return bisect.bisect_right(self.decay_t, t + 1)

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
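
For context on the lines above the changed signature: bisect.bisect_right over the sorted decay_t milestones counts how many decay events have occurred by step t. A standalone illustration with hypothetical milestones:

import bisect
from typing import List

decay_t: List[int] = [30, 60, 90]   # hypothetical sorted decay milestones

def num_decays(t: int) -> int:
    # bisect_right(decay_t, t + 1) counts how many milestones are <= t + 1,
    # i.e. how many decay events have occurred by step t.
    return bisect.bisect_right(decay_t, t + 1)

print(num_decays(0))    # 0 -- before any milestone
print(num_decays(29))   # 1 -- t + 1 == 30 hits the first milestone
print(num_decays(75))   # 2 -- past 30 and 60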

timm/scheduler/plateau_lr.py

@@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from typing import List
 from .scheduler import Scheduler
@@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler):
             param_group['lr'] = new_lr
         self.restore_lr = restore_lr

-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         assert False, 'should not be called as step is overridden'
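
PlateauLRScheduler is the odd one out: it wraps torch's ReduceLROnPlateau and overrides step, so its _get_lr is unreachable and exists only to satisfy the base class; the annotation change here is purely for interface consistency. A hypothetical minimal sketch of that pattern (illustrative class names, not timm's code):

from typing import List

class Base:
    def _get_lr(self, t: int) -> List[float]:
        raise NotImplementedError

    def step(self, t: int) -> None:
        # The usual path: step() asks _get_lr() for the new values.
        lrs = self._get_lr(t)

class PlateauLike(Base):
    def step(self, t: int) -> None:
        # Delegates to a wrapped torch ReduceLROnPlateau instead of calling
        # _get_lr(), which is why the real method can assert it is never hit.
        pass

    def _get_lr(self, t: int) -> List[float]:
        assert False, 'should not be called as step is overridden'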

timm/scheduler/poly_lr.py

@@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
 """
 import math
 import logging
+from typing import List
 import torch
@@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
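
After warmup, PolyLRScheduler decays each LR polynomially toward a minimum. A hedged standalone sketch of the common (1 - t/t_max) ** power form; timm's version also handles cycles and noise, omitted here, and its exact parameterization may differ:

from typing import List

lr_min = 0.0
power = 2.0                      # hypothetical decay exponent
t_max = 100                      # hypothetical schedule length
base_values: List[float] = [0.1]

def poly_lr(t: int) -> List[float]:
    frac = min(t / t_max, 1.0)   # progress in [0, 1]
    return [lr_min + (v - lr_min) * (1 - frac) ** power for v in base_values]

print(poly_lr(0))    # [0.1]
print(poly_lr(50))   # [0.025] -- (1 - 0.5)**2 * 0.1
print(poly_lr(100))  # [0.0]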

timm/scheduler/scheduler.py

@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 import torch
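
This is the hunk the rest of the commit hangs off: the base class's typing import gains List so the shared _get_lr contract can be spelled as one LR per param group. A minimal illustrative sketch of that contract (not timm's full Scheduler):

import abc
from abc import ABC
from typing import List

class SchedulerLike(ABC):
    @abc.abstractmethod
    def _get_lr(self, t: int) -> List[float]:
        # Contract: return one learning rate per optimizer param group.
        ...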

timm/scheduler/step_lr.py

@@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import torch
+from typing import List
+
 from .scheduler import Scheduler
@@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
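
StepLRScheduler's post-warmup branch multiplies each base LR by a decay rate once per completed decay interval. A standalone sketch with hypothetical values:

from typing import List

decay_t = 30                     # hypothetical decay interval, in steps
decay_rate = 0.1
base_values: List[float] = [0.1]

def step_lr(t: int) -> List[float]:
    # One multiplicative decay per completed interval.
    return [v * decay_rate ** (t // decay_t) for v in base_values]

print(step_lr(0))    # [0.1]
print(step_lr(30))   # ~[0.01]  -- one decay
print(step_lr(65))   # ~[0.001] -- two decays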

timm/scheduler/tanh_lr.py

@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 from .scheduler import Scheduler
@@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]

-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
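
TanhLRScheduler anneals each LR along a tanh curve after warmup. A heavily hedged sketch of the generic form, sweeping the tanh argument linearly from a lower to an upper bound; the bounds, minimum LR, and cycle handling here are assumptions, not timm's exact defaults:

import math
from typing import List

lb, ub = -7.0, 3.0               # hypothetical tanh sweep bounds
lr_min = 1e-5
t_max = 100                      # hypothetical schedule length
base_values: List[float] = [0.1]

def tanh_lr(t: int) -> List[float]:
    tr = min(t / t_max, 1.0)     # progress in [0, 1]
    x = lb * (1.0 - tr) + ub * tr
    return [lr_min + 0.5 * (v - lr_min) * (1.0 - math.tanh(x)) for v in base_values]

print(tanh_lr(0))    # ~[0.1]  -- tanh(-7) is close to -1, so LR starts near the base value
print(tanh_lr(100))  # small, approaching lr_min as tanh(3) nears 1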