Merge pull request #2124 from fzyzcjy/patch-1

Fix super tiny type error
Ross Wightman 2024-04-02 14:31:38 -07:00 committed by GitHub
commit 59b3d86c1d
7 changed files with 15 additions and 9 deletions

timm/scheduler/cosine_lr.py

@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
@@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

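The warmup branch visible in this hunk (and repeated in the hunks below) is why the corrected annotation is List[float]: the scheduler keeps one base value per optimizer param group and ramps each one linearly. A minimal self-contained sketch of that arithmetic, with made-up values; only the list-comprehension expression comes from the diff:

# Hypothetical values, for illustration only.
warmup_lr_init = 1e-5
warmup_t = 5
base_values = [0.1, 0.01]   # one base LR per optimizer param group

# timm's warmup setup (not shown in the hunk) derives per-group step sizes
# so each LR ramps linearly from warmup_lr_init to its base value.
warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]

for t in range(warmup_t + 1):
    lrs = [warmup_lr_init + t * s for s in warmup_steps]  # expression from the hunk
    print(t, lrs)  # t=0 -> [1e-05, 1e-05], ..., t=5 -> [0.1, 0.01] (approximately)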
timm/scheduler/multistep_lr.py

@@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler):
         # assumes self.decay_t is sorted
         return bisect.bisect_right(self.decay_t, t + 1)
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

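The bisect call in this hunk feeds the non-warmup branch: it counts how many sorted decay milestones m satisfy m <= t + 1, and each param group's LR is its base value scaled by decay_rate raised to that count. A small worked example with made-up milestones:

import bisect

decay_t = [30, 60, 90]   # hypothetical sorted decay milestones
decay_rate = 0.1
base_values = [0.1]      # one base LR per param group

for t in [0, 29, 59, 95]:
    n = bisect.bisect_right(decay_t, t + 1)            # milestones with m <= t + 1
    lrs = [v * decay_rate ** n for v in base_values]   # MultiStep-style decay
    print(t, n, lrs)  # n steps through 0, 1, 2, 3; each step shrinks the LR 10x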
timm/scheduler/plateau_lr.py

@@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from typing import List
 
 from .scheduler import Scheduler
@@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler):
                 param_group['lr'] = new_lr
         self.restore_lr = restore_lr
 
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         assert False, 'should not be called as step is overridden'

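This _get_lr only asserts because, as the docstring above says, the class adapts PyTorch's plateau scheduler: step() is overridden to delegate to torch.optim.lr_scheduler.ReduceLROnPlateau, so _get_lr is never reached and the change merely keeps its signature consistent with the base class. A simplified, hedged sketch of the underlying torch behavior (not timm's actual code):

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.1, patience=2)

for epoch in range(6):
    val_loss = 1.0          # a stagnant metric eventually triggers a decay
    plateau.step(val_loss)  # timm's overridden step(epoch, metric) wraps a call like this
    print(epoch, [g['lr'] for g in opt.param_groups])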
timm/scheduler/poly_lr.py

@@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
 """
 import math
 import logging
+from typing import List
 
 import torch
@@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

timm/scheduler/scheduler.py

@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
@@ -65,10 +65,10 @@ class Scheduler(ABC):
         self.__dict__.update(state_dict)
 
     @abc.abstractmethod
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         pass
 
-    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
+    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
         proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
         if not proceed:
             return None

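This base-class hunk is the one that actually fixes the contract: _get_lr must return one LR per param group, and _get_values forwards that list (or None, when the step/epoch granularity doesn't match) to the group-update machinery. A minimal hypothetical subclass illustrating the annotated contract; ConstantLRScheduler is invented for this sketch and assumes timm is installed:

from typing import List

import torch
from timm.scheduler.scheduler import Scheduler


class ConstantLRScheduler(Scheduler):
    # Invented example subclass; not part of the commit or of timm.
    def _get_lr(self, t: int) -> List[float]:
        # One value per param group, aligned with self.base_values.
        return list(self.base_values)


opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.05)
sched = ConstantLRScheduler(opt, param_group_field='lr')
sched.step(0)                               # reaches _get_lr via _get_values
print([g['lr'] for g in opt.param_groups])  # [0.05]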
timm/scheduler/step_lr.py

@@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import torch
+from typing import List
+
 
 from .scheduler import Scheduler
@@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:

timm/scheduler/tanh_lr.py

@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
@@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
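
Taken together, the seven files now advertise the same per-group contract end to end. A short usage sketch (assuming timm is installed; the hyperparameter values are arbitrary) showing that the scheduled LRs really are a list, one per param group:

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

w1 = torch.nn.Parameter(torch.zeros(1))
w2 = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([
    {'params': [w1], 'lr': 0.1},
    {'params': [w2], 'lr': 0.01},
])
sched = CosineLRScheduler(opt, t_initial=100, warmup_t=5, warmup_lr_init=1e-5)

for epoch in range(3):
    sched.step(epoch)  # internally calls _get_lr(epoch) -> List[float]
    print(epoch, [g['lr'] for g in opt.param_groups])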