pull/491/merge
dmitrysarov 2024-12-19 09:03:05 +01:00 committed by GitHub
commit 180ed7bbd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 10 deletions

View File

@ -18,7 +18,7 @@ import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger
from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler
from dinov2.utils.utils import MemEfficientCosineScheduler
from dinov2.train.ssl_meta_arch import SSLMetaArch
@ -89,15 +89,14 @@ def build_schedulers(cfg):
start_warmup_value=cfg.teacher["warmup_teacher_temp"],
)
lr_schedule = CosineScheduler(**lr)
wd_schedule = CosineScheduler(**wd)
momentum_schedule = CosineScheduler(**momentum)
teacher_temp_schedule = CosineScheduler(**teacher_temp)
last_layer_lr_schedule = CosineScheduler(**lr)
last_layer_lr_schedule.schedule[
: cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
] = 0 # mimicking the original schedules
lr_schedule = MemEfficientCosineScheduler(**lr)
wd_schedule = MemEfficientCosineScheduler(**wd)
momentum_schedule = MemEfficientCosineScheduler(**momentum)
teacher_temp_schedule = MemEfficientCosineScheduler(**teacher_temp)
# this is a hack to mimic the original schedules
_lr = lr.copy()
_lr.update(freeze_iters=cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH)
last_layer_lr_schedule = MemEfficientCosineScheduler(**_lr)
logger.info("Schedulers ready.")

View File

@ -85,7 +85,38 @@ class CosineScheduler(object):
return self.final_value
else:
return self.schedule[it]
class MemEfficientCosineScheduler:
    """Cosine schedule computed on the fly instead of as a precomputed array.

    Drop-in replacement for ``CosineScheduler``: it returns the same values via
    ``scheduler[it]``, but computes each value lazily so no array of length
    ``total_iters`` is materialized (O(1) memory instead of O(total_iters)).

    Args:
        base_value: peak value, reached at the end of warmup / start of cosine decay.
        final_value: value returned once ``it >= total_iters``.
        total_iters: total schedule length, in iterations.
        warmup_iters: number of linear-warmup iterations after the freeze phase.
        start_warmup_value: value the linear warmup starts from.
        freeze_iters: number of leading iterations for which 0.0 is returned
            (used e.g. to freeze the last layer at the start of training).
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.base_value = base_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.start_warmup_value = start_warmup_value
        self.freeze_iters = freeze_iters

    def __getitem__(self, it):
        """Return the scheduled value at iteration ``it``."""
        if it >= self.total_iters:
            return self.final_value
        if it < self.freeze_iters:
            # Frozen phase: the parameter group is held at zero.
            return 0.0
        if it < self.freeze_iters + self.warmup_iters:
            # Linear warmup. The array-based CosineScheduler builds this segment
            # with np.linspace(start_warmup_value, base_value, warmup_iters),
            # whose step is 1 / (warmup_iters - 1) — NOT 1 / warmup_iters — so
            # the final warmup iteration returns exactly base_value. Dividing by
            # warmup_iters (the previous code) never reached base_value and
            # diverged from the precomputed schedule at every warmup step.
            # max(1, ...) covers warmup_iters == 1, where linspace yields the
            # single value start_warmup_value.
            alpha = (it - self.freeze_iters) / max(1, self.warmup_iters - 1)
            return self.start_warmup_value * (1 - alpha) + self.base_value * alpha
        # Cosine decay from base_value down to final_value over the remaining iters.
        effective_it = it - self.freeze_iters - self.warmup_iters
        cosine_iters = self.total_iters - self.warmup_iters - self.freeze_iters
        return self.final_value + 0.5 * (self.base_value - self.final_value) * (
            1 + np.cos(np.pi * effective_it / cosine_iters)
        )
def has_batchnorms(model):
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)