From 065764ad215662ca90248a5de1c703feda149199 Mon Sep 17 00:00:00 2001 From: Dmitrii Victorovich Shaulskii Date: Thu, 19 Dec 2024 09:00:04 +0100 Subject: [PATCH] mem efficient cosine scheduler --- dinov2/train/train.py | 19 +++++++++---------- dinov2/utils/utils.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 473b8d0..2ed24a6 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -18,7 +18,7 @@ import dinov2.distributed as distributed from dinov2.fsdp import FSDPCheckpointer from dinov2.logging import MetricLogger from dinov2.utils.config import setup -from dinov2.utils.utils import CosineScheduler +from dinov2.utils.utils import MemEfficientCosineScheduler from dinov2.train.ssl_meta_arch import SSLMetaArch @@ -89,15 +89,14 @@ def build_schedulers(cfg): start_warmup_value=cfg.teacher["warmup_teacher_temp"], ) - lr_schedule = CosineScheduler(**lr) - wd_schedule = CosineScheduler(**wd) - momentum_schedule = CosineScheduler(**momentum) - teacher_temp_schedule = CosineScheduler(**teacher_temp) - last_layer_lr_schedule = CosineScheduler(**lr) - - last_layer_lr_schedule.schedule[ - : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH - ] = 0 # mimicking the original schedules + lr_schedule = MemEfficientCosineScheduler(**lr) + wd_schedule = MemEfficientCosineScheduler(**wd) + momentum_schedule = MemEfficientCosineScheduler(**momentum) + teacher_temp_schedule = MemEfficientCosineScheduler(**teacher_temp) + # this is a hack to mimic the original schedules + _lr = lr.copy() + _lr.update(freeze_iters=cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH) + last_layer_lr_schedule = MemEfficientCosineScheduler(**_lr) logger.info("Schedulers ready.") diff --git a/dinov2/utils/utils.py b/dinov2/utils/utils.py index 68f8e2c..0406fb1 100644 --- a/dinov2/utils/utils.py +++ b/dinov2/utils/utils.py @@ -85,7 +85,38 @@ class CosineScheduler(object): return self.final_value else: return self.schedule[it] + +class MemEfficientCosineScheduler: + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + self.start_warmup_value = start_warmup_value + self.base_value = base_value + self.freeze_iters = freeze_iters + self.warmup_iters = warmup_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + + if it < self.freeze_iters: + return 0.0 + + if it < self.freeze_iters + self.warmup_iters: + # Linear warmup - fixed to match original implementation + alpha = (it - self.freeze_iters) / max(1, self.warmup_iters) + value = self.start_warmup_value * (1 - alpha) + self.base_value * alpha + return value + + # Cosine schedule - this part needed adjustment to match CosineScheduler + effective_it = it - self.freeze_iters - self.warmup_iters + total_cosine_iters = self.total_iters - self.warmup_iters - self.freeze_iters + return self.final_value + 0.5 * (self.base_value - self.final_value) * ( + 1 + np.cos(np.pi * effective_it / total_cosine_iters) + ) + def has_batchnorms(model): bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)