Merge 065764ad21
into e1277af2ba
commit
180ed7bbd3
|
@ -18,7 +18,7 @@ import dinov2.distributed as distributed
|
|||
from dinov2.fsdp import FSDPCheckpointer
|
||||
from dinov2.logging import MetricLogger
|
||||
from dinov2.utils.config import setup
|
||||
from dinov2.utils.utils import CosineScheduler
|
||||
from dinov2.utils.utils import MemEfficientCosineScheduler
|
||||
|
||||
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
||||
|
||||
|
@ -89,15 +89,14 @@ def build_schedulers(cfg):
|
|||
start_warmup_value=cfg.teacher["warmup_teacher_temp"],
|
||||
)
|
||||
|
||||
lr_schedule = CosineScheduler(**lr)
|
||||
wd_schedule = CosineScheduler(**wd)
|
||||
momentum_schedule = CosineScheduler(**momentum)
|
||||
teacher_temp_schedule = CosineScheduler(**teacher_temp)
|
||||
last_layer_lr_schedule = CosineScheduler(**lr)
|
||||
|
||||
last_layer_lr_schedule.schedule[
|
||||
: cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
|
||||
] = 0 # mimicking the original schedules
|
||||
lr_schedule = MemEfficientCosineScheduler(**lr)
|
||||
wd_schedule = MemEfficientCosineScheduler(**wd)
|
||||
momentum_schedule = MemEfficientCosineScheduler(**momentum)
|
||||
teacher_temp_schedule = MemEfficientCosineScheduler(**teacher_temp)
|
||||
# this is a hack to mimic the original schedules
|
||||
_lr = lr.copy()
|
||||
_lr.update(freeze_iters=cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH)
|
||||
last_layer_lr_schedule = MemEfficientCosineScheduler(**_lr)
|
||||
|
||||
logger.info("Schedulers ready.")
|
||||
|
||||
|
|
|
@ -85,7 +85,38 @@ class CosineScheduler(object):
|
|||
return self.final_value
|
||||
else:
|
||||
return self.schedule[it]
|
||||
|
||||
|
||||
class MemEfficientCosineScheduler:
    """Cosine schedule evaluated lazily, one iteration at a time.

    Instead of materializing a full ``total_iters``-long array up front,
    each lookup computes its value on demand, so memory use is O(1).
    The schedule has three consecutive phases:

    1. ``[0, freeze_iters)`` — held at ``0.0`` (frozen parameter group);
    2. ``[freeze_iters, freeze_iters + warmup_iters)`` — linear warmup
       from ``start_warmup_value`` to ``base_value``;
    3. the remainder up to ``total_iters`` — cosine decay from
       ``base_value`` down to ``final_value``.

    Any index at or beyond ``total_iters`` clamps to ``final_value``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        # Endpoints of the cosine phase.
        self.base_value = base_value
        self.final_value = final_value
        # Overall horizon and phase lengths.
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.freeze_iters = freeze_iters
        # Value at the very first warmup step.
        self.start_warmup_value = start_warmup_value

    def __getitem__(self, it):
        """Return the scheduled value for iteration ``it``."""
        # Past the horizon: clamp to the terminal value.
        if it >= self.total_iters:
            return self.final_value

        # Phase 1: frozen — the parameter group gets a zero value.
        if it < self.freeze_iters:
            return 0.0

        warmup_end = self.freeze_iters + self.warmup_iters

        # Phase 2: linear warmup between start_warmup_value and base_value.
        # NOTE(review): at the last warmup step the fraction is
        # (warmup_iters - 1) / warmup_iters, i.e. base_value is only
        # reached as the cosine phase begins — confirm this matches the
        # precomputed CosineScheduler if exact parity matters.
        if it < warmup_end:
            frac = (it - self.freeze_iters) / max(1, self.warmup_iters)
            return self.start_warmup_value * (1 - frac) + self.base_value * frac

        # Phase 3: half-cosine from base_value to final_value over the
        # iterations remaining after freeze + warmup.
        steps_into_cosine = it - warmup_end
        cosine_span = self.total_iters - warmup_end
        return self.final_value + 0.5 * (self.base_value - self.final_value) * (
            1 + np.cos(np.pi * steps_into_cosine / cosine_span)
        )
|
||||
|
||||
|
||||
def has_batchnorms(model):
|
||||
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
||||
|
|
Loading…
Reference in New Issue