Extend train epoch schedule by warmup_epochs if warmup_prefix enable, allows schedule to reach end w/ prefix enabledy

pull/2331/head
Ross Wightman 2024-11-07 10:24:06 -08:00 committed by Ross Wightman
parent 7f0c1b1f30
commit 363b043c13
4 changed files with 14 additions and 7 deletions

View File

@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler):
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
t = self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
return t + self.warmup_t if self.warmup_prefix else t

View File

@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler):
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
t = self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
return t + self.warmup_t if self.warmup_prefix else t

View File

@ -196,11 +196,15 @@ def create_scheduler_v2(
)
if hasattr(lr_scheduler, 'get_cycle_length'):
# for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
# For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
# NOTE: Warmup prefix added in get_cycle_lengths() if enabled
t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
if step_on_epochs:
num_epochs = t_with_cycles_and_cooldown
else:
num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
else:
if warmup_prefix:
num_epochs += warmup_epochs
return lr_scheduler, num_epochs

View File

@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler):
def get_cycle_length(self, cycles=0):
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
t = self.t_initial * cycles
else:
return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
return t + self.warmup_t if self.warmup_prefix else t