Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Extend train epoch schedule by warmup_epochs if warmup_prefix enabled, allows schedule to reach end w/ prefix enabled
This commit is contained in:
parent
7f0c1b1f30
commit
363b043c13
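
In short: with warmup_prefix enabled, warmup runs before the decay schedule instead of consuming its first warmup_t steps, so get_cycle_length() previously undercounted the total by warmup_t and training ended before the decay did. A minimal before/after sketch (illustrative values, not from the repo):

# Illustrative sketch only: 100 decay steps with a 5-step warmup prefix.
t_initial, warmup_t, warmup_prefix = 100, 5, True
cycles = 1

old_len = t_initial * cycles  # before this commit: 100, warmup prefix ignored
new_len = old_len + warmup_t if warmup_prefix else old_len  # after: 105
print(old_len, new_len)  # 100 105 -> the schedule now reaches its final decay step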
@@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
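
When cycle_mul != 1.0, cycle i runs for t_initial * cycle_mul**i steps, so n cycles sum to the geometric series t_initial * (cycle_mul**n - 1) / (cycle_mul - 1); the code above writes the same value with the sign flipped. A standalone check of that formula (illustrative values):

import math

t_initial, cycle_mul, cycles = 10, 2.0, 3
by_sum = sum(int(t_initial * cycle_mul ** i) for i in range(cycles))  # 10 + 20 + 40
by_formula = int(math.floor(-t_initial * (cycle_mul ** cycles - 1) / (1 - cycle_mul)))
print(by_sum, by_formula)  # 70 70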
@@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
@@ -196,11 +196,15 @@ def create_scheduler_v2(
     )

     if hasattr(lr_scheduler, 'get_cycle_length'):
-        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # NOTE: Warmup prefix added in get_cycle_length() if enabled
         t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
         if step_on_epochs:
             num_epochs = t_with_cycles_and_cooldown
         else:
             num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
+    else:
+        if warmup_prefix:
+            num_epochs += warmup_epochs

     return lr_scheduler, num_epochs
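
The new else branch covers schedulers without get_cycle_length() (e.g. step or plateau style): there num_epochs is extended directly. A standalone sketch of the factory's logic (names mirror the diff, values illustrative):

num_epochs, warmup_epochs, warmup_prefix = 300, 5, True
has_cycle_length = False  # stands in for hasattr(lr_scheduler, 'get_cycle_length')

if has_cycle_length:
    pass  # cycle-based schedulers fold warmup_t into get_cycle_length() instead
elif warmup_prefix:
    num_epochs += warmup_epochs  # the train loop now runs long enough to cover the prefix
print(num_epochs)  # 305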
@@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
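
From the caller's side, the returned epoch count now includes the warmup prefix. A hedged usage sketch against timm's scheduler factory (keyword arguments assumed as shown; verify against your timm version):

import torch
from timm.scheduler import create_scheduler_v2

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# With warmup_prefix=True, total_epochs should come back as 100 + 5 = 105 after
# this commit, so the loop below covers the full cosine decay plus warmup.
scheduler, total_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    warmup_prefix=True,
)
for epoch in range(total_epochs):
    scheduler.step(epoch)  # timm schedulers step on the epoch index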