mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
This commit is contained in:
parent 8880a5cd5c
commit b44e4e45a2
timm/scheduler/cosine_lr.py
@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -77,7 +78,7 @@ class CosineLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
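The warmup branch shown above is shared by several of these schedulers: for the first warmup_t steps each param group's LR ramps linearly from warmup_lr_init, and the new annotation documents that _get_lr returns one value per param group. A minimal standalone sketch of that branch, using illustrative values rather than timm defaults:

from typing import List

warmup_lr_init = 1e-6
warmup_t = 5
base_values = [0.1, 0.01]  # e.g. two optimizer param groups

# per-step increment so each group reaches its base value at t == warmup_t
warmup_steps = [(v - warmup_lr_init) / warmup_t for v in base_values]

def get_warmup_lr(t: int) -> List[float]:
    # one learning rate per param group, matching the List[float] annotation
    return [warmup_lr_init + t * s for s in warmup_steps]

print(get_warmup_lr(0))  # [1e-06, 1e-06]
print(get_warmup_lr(5))  # ~[0.1, 0.01]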
timm/scheduler/multistep_lr.py
@@ -53,7 +53,7 @@ class MultiStepLRScheduler(Scheduler):
         # assumes self.decay_t is sorted
         return bisect.bisect_right(self.decay_t, t + 1)
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
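For context on the bisect call above: bisect_right over the sorted decay_t milestones counts how many milestones m satisfy m <= t + 1, and the scheduler raises the decay rate to that power. A hedged sketch with illustrative values (decay_t, decay_rate, and base_values here are not timm defaults):

import bisect
from typing import List

decay_t = [30, 60, 90]  # milestone epochs, must stay sorted for bisect
decay_rate = 0.1
base_values = [0.1]

def get_curr_decay_steps(t: int) -> int:
    # count of milestones m with m <= t + 1
    return bisect.bisect_right(decay_t, t + 1)

def get_lr(t: int) -> List[float]:
    return [v * decay_rate ** get_curr_decay_steps(t) for v in base_values]

print(get_lr(0))   # [0.1]
print(get_lr(30))  # [0.01]  first milestone applied
print(get_lr(95))  # ~[0.0001]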
timm/scheduler/plateau_lr.py
@@ -5,6 +5,7 @@ Adapts PyTorch plateau scheduler and allows application of noise, warmup.
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -106,5 +107,5 @@ class PlateauLRScheduler(Scheduler):
                 param_group['lr'] = new_lr
             self.restore_lr = restore_lr
 
-    def _get_lr(self, t: int) -> float:
+    def _get_lr(self, t: int) -> List[float]:
         assert False, 'should not be called as step is overridden'
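The assert above is intentional: PlateauLRScheduler overrides step() and drives the optimizer through PyTorch's ReduceLROnPlateau, so the base-class path that calls _get_lr is never taken; the annotation change only keeps the dead signature consistent with the other schedulers. A simplified sketch of that override pattern (class names here are illustrative, not the timm classes):

from typing import List

class Base:
    def step(self, epoch: int) -> None:
        # the base class would push these values into the optimizer
        for _ in self._get_lr(epoch):
            pass
    def _get_lr(self, t: int) -> List[float]:
        raise NotImplementedError

class PlateauLike(Base):
    def step(self, epoch: int, metric: float = 0.0) -> None:
        # the real class delegates to torch.optim.lr_scheduler.ReduceLROnPlateau here
        pass
    def _get_lr(self, t: int) -> List[float]:
        assert False, 'should not be called as step is overridden'

PlateauLike().step(0, metric=0.5)  # completes without touching _get_lr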
timm/scheduler/poly_lr.py
@@ -6,6 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
 """
 import math
 import logging
+from typing import List
 
 import torch
 
@@ -73,7 +74,7 @@ class PolyLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
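PolyLRScheduler's post-warmup branch falls outside this hunk; it applies a polynomial decay toward a minimum LR. A sketch of that formula under a single-cycle assumption, with illustrative constants:

from typing import List

base_values = [0.1]
lr_min = 0.0
power = 2.0
t_initial = 100

def get_lr(t: int) -> List[float]:
    # polynomial ramp from each base value down to lr_min over t_initial steps
    frac = 1 - t / t_initial
    return [lr_min + (v - lr_min) * frac ** power for v in base_values]

print(get_lr(0))    # [0.1]
print(get_lr(50))   # [0.025]
print(get_lr(100))  # [0.0]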
timm/scheduler/scheduler.py
@@ -1,6 +1,6 @@
 import abc
 from abc import ABC
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 
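The base-class change above only extends the typing import, presumably so the shared _get_lr signature elsewhere in scheduler.py can carry the same List[float] annotation. A rough sketch of the contract the subclasses now all satisfy (bodies are simplified placeholders, not the timm source):

import abc
from abc import ABC
from typing import List

import torch

class SchedulerSketch(ABC):
    """One LR value per optimizer param group, recomputed from step t."""
    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self.optimizer = optimizer
        self.base_values = [float(g['lr']) for g in optimizer.param_groups]

    @abc.abstractmethod
    def _get_lr(self, t: int) -> List[float]:
        ...  # subclasses return one value per param group for step t

    def update_groups(self, values: List[float]) -> None:
        # write the computed values back, pairwise with the param groups
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group['lr'] = value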
timm/scheduler/step_lr.py
@@ -6,6 +6,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
 import torch
+from typing import List
+
 
 from .scheduler import Scheduler
 
@@ -51,7 +53,7 @@ class StepLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
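StepLRScheduler's non-warmup branch, outside this hunk, is a fixed-interval decay. A sketch of that formula, assuming decay_t is the interval in epochs (illustrative values, not timm defaults):

from typing import List

base_values = [0.1]
decay_t = 30
decay_rate = 0.1

def get_lr(t: int) -> List[float]:
    # every decay_t epochs, multiply each group's LR by decay_rate
    return [v * decay_rate ** (t // decay_t) for v in base_values]

print(get_lr(0), get_lr(30), get_lr(65))  # ~[0.1] [0.01] [0.001]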
timm/scheduler/tanh_lr.py
@@ -8,6 +8,7 @@ import logging
 import math
 import numpy as np
 import torch
+from typing import List
 
 from .scheduler import Scheduler
 
@@ -75,7 +76,7 @@ class TanhLRScheduler(Scheduler):
         else:
             self.warmup_steps = [1 for _ in self.base_values]
 
-    def _get_lr(self, t):
+    def _get_lr(self, t: int) -> List[float]:
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
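Taken together, the commit gives every scheduler's _get_lr the same (int) -> List[float] shape. A hedged end-to-end sketch of what that means for a caller; _get_lr is a private method and is called directly here only to show the annotated return, and the constructor arguments are common ones, check your timm version:

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineLRScheduler(
    optimizer, t_initial=100, warmup_t=5, warmup_lr_init=1e-5)

for epoch in range(10):
    lrs = scheduler._get_lr(epoch)  # annotated: (int) -> List[float]
    print(epoch, [f'{lr:.5f}' for lr in lrs])
    scheduler.step(epoch)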
|
Loading…
x
Reference in New Issue
Block a user