Extend train epoch schedule by warmup_epochs if warmup_prefix enable, allows schedule to reach end w/ prefix enabledy
parent
7f0c1b1f30
commit
363b043c13
|
@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler):
|
|||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
return self.t_initial * cycles
|
||||
t = self.t_initial * cycles
|
||||
else:
|
||||
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
return t + self.warmup_t if self.warmup_prefix else t
|
|
@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler):
|
|||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
return self.t_initial * cycles
|
||||
t = self.t_initial * cycles
|
||||
else:
|
||||
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
return t + self.warmup_t if self.warmup_prefix else t
|
||||
|
|
|
@ -196,11 +196,15 @@ def create_scheduler_v2(
|
|||
)
|
||||
|
||||
if hasattr(lr_scheduler, 'get_cycle_length'):
|
||||
# for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
|
||||
# For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
|
||||
# NOTE: Warmup prefix added in get_cycle_lengths() if enabled
|
||||
t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
|
||||
if step_on_epochs:
|
||||
num_epochs = t_with_cycles_and_cooldown
|
||||
else:
|
||||
num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
|
||||
else:
|
||||
if warmup_prefix:
|
||||
num_epochs += warmup_epochs
|
||||
|
||||
return lr_scheduler, num_epochs
|
||||
|
|
|
@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler):
|
|||
def get_cycle_length(self, cycles=0):
|
||||
cycles = max(1, cycles or self.cycle_limit)
|
||||
if self.cycle_mul == 1.0:
|
||||
return self.t_initial * cycles
|
||||
t = self.t_initial * cycles
|
||||
else:
|
||||
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
|
||||
return t + self.warmup_t if self.warmup_prefix else t
|
||||
|
|
Loading…
Reference in New Issue