pull/491/merge
dmitrysarov 2024-12-19 09:03:05 +01:00 committed by GitHub
commit 180ed7bbd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 10 deletions

View File

@ -18,7 +18,7 @@ import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger from dinov2.logging import MetricLogger
from dinov2.utils.config import setup from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler from dinov2.utils.utils import MemEfficientCosineScheduler
from dinov2.train.ssl_meta_arch import SSLMetaArch from dinov2.train.ssl_meta_arch import SSLMetaArch
@ -89,15 +89,14 @@ def build_schedulers(cfg):
start_warmup_value=cfg.teacher["warmup_teacher_temp"], start_warmup_value=cfg.teacher["warmup_teacher_temp"],
) )
lr_schedule = CosineScheduler(**lr) lr_schedule = MemEfficientCosineScheduler(**lr)
wd_schedule = CosineScheduler(**wd) wd_schedule = MemEfficientCosineScheduler(**wd)
momentum_schedule = CosineScheduler(**momentum) momentum_schedule = MemEfficientCosineScheduler(**momentum)
teacher_temp_schedule = CosineScheduler(**teacher_temp) teacher_temp_schedule = MemEfficientCosineScheduler(**teacher_temp)
last_layer_lr_schedule = CosineScheduler(**lr) # this is a hack to mimic the original schedules
_lr = lr.copy()
last_layer_lr_schedule.schedule[ _lr.update(freeze_iters=cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH)
: cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH last_layer_lr_schedule = MemEfficientCosineScheduler(**_lr)
] = 0 # mimicking the original schedules
logger.info("Schedulers ready.") logger.info("Schedulers ready.")

View File

@ -87,6 +87,37 @@ class CosineScheduler(object):
return self.schedule[it] return self.schedule[it]
class MemEfficientCosineScheduler:
    """Cosine schedule with optional freeze and linear warmup, computed lazily.

    No per-iteration table is stored: the value for iteration ``it`` is
    evaluated in closed form each time the scheduler is indexed. The schedule
    consists of three consecutive phases:

    1. freeze: the first ``freeze_iters`` iterations return ``0.0``;
    2. warmup: the next ``warmup_iters`` iterations interpolate linearly from
       ``start_warmup_value`` to ``base_value``, both endpoints included;
    3. cosine: the remaining iterations follow a half cosine that starts at
       ``base_value`` and decays towards ``final_value``.

    Any index greater than or equal to ``total_iters`` returns ``final_value``.

    NOTE(review): the warmup spacing is meant to reproduce the table-based
    ``CosineScheduler`` that this class replaces (assumed to use an
    endpoint-inclusive ``np.linspace``) -- confirm against that class.

    Args:
        base_value: Value reached at the end of warmup / start of the cosine.
        final_value: Value the cosine decays towards, and the value returned
            once ``it >= total_iters``.
        total_iters: Total schedule length (freeze + warmup + cosine).
        warmup_iters: Number of linear warmup iterations.
        start_warmup_value: Value at the first warmup iteration.
        freeze_iters: Number of leading iterations that return ``0.0``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters
        self.start_warmup_value = start_warmup_value
        self.base_value = base_value
        self.freeze_iters = freeze_iters
        self.warmup_iters = warmup_iters

    def __getitem__(self, it):
        """Return the scheduled value for iteration ``it``."""
        # Past the end of the schedule the value stays at its final level.
        if it >= self.total_iters:
            return self.final_value

        # Freeze phase: the value is held at exactly zero.
        if it < self.freeze_iters:
            return 0.0

        warmup_end = self.freeze_iters + self.warmup_iters
        if it < warmup_end:
            # Endpoint-inclusive linear warmup: dividing by (warmup_iters - 1)
            # makes the last warmup step land exactly on ``base_value``, the
            # same spacing np.linspace(start, base, warmup_iters) produces.
            # max(1, ...) keeps warmup_iters == 1 well defined (alpha == 0,
            # i.e. the single warmup step returns ``start_warmup_value``).
            alpha = (it - self.freeze_iters) / max(1, self.warmup_iters - 1)
            return self.start_warmup_value * (1 - alpha) + self.base_value * alpha

        # Cosine phase. ``it < total_iters`` and ``it >= warmup_end`` hold
        # here, so ``total_cosine_iters`` is strictly positive.
        effective_it = it - warmup_end
        total_cosine_iters = self.total_iters - warmup_end
        return self.final_value + 0.5 * (self.base_value - self.final_value) * (
            1 + np.cos(np.pi * effective_it / total_cosine_iters)
        )
def has_batchnorms(model): def has_batchnorms(model):
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
for name, module in model.named_modules(): for name, module in model.named_modules():