diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index 4eaaa86a..00dd9357 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler): def get_cycle_length(self, cycles=0): cycles = max(1, cycles or self.cycle_limit) if self.cycle_mul == 1.0: - return self.t_initial * cycles + t = self.t_initial * cycles else: - return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) + t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) + return t + self.warmup_t if self.warmup_prefix else t \ No newline at end of file diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py index 8875e15b..f7971302 100644 --- a/timm/scheduler/poly_lr.py +++ b/timm/scheduler/poly_lr.py @@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler): def get_cycle_length(self, cycles=0): cycles = max(1, cycles or self.cycle_limit) if self.cycle_mul == 1.0: - return self.t_initial * cycles + t = self.t_initial * cycles else: - return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) + t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) + return t + self.warmup_t if self.warmup_prefix else t diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index caf68fad..08c5e180 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -196,11 +196,15 @@ def create_scheduler_v2( ) if hasattr(lr_scheduler, 'get_cycle_length'): - # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown + # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown + # NOTE: Warmup prefix is added in get_cycle_length() if enabled t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t if step_on_epochs: num_epochs = 
t_with_cycles_and_cooldown else: num_epochs = t_with_cycles_and_cooldown // updates_per_epoch + else: + if warmup_prefix: + num_epochs += warmup_epochs return lr_scheduler, num_epochs diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index 94455302..93222926 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler): def get_cycle_length(self, cycles=0): cycles = max(1, cycles or self.cycle_limit) if self.cycle_mul == 1.0: - return self.t_initial * cycles + t = self.t_initial * cycles else: - return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) + t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) + return t + self.warmup_t if self.warmup_prefix else t