Merge 065764ad21 into e1277af2ba
commit 180ed7bbd3
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
@@ -18,7 +18,7 @@ import dinov2.distributed as distributed
 from dinov2.fsdp import FSDPCheckpointer
 from dinov2.logging import MetricLogger
 from dinov2.utils.config import setup
-from dinov2.utils.utils import CosineScheduler
+from dinov2.utils.utils import MemEfficientCosineScheduler
 
 from dinov2.train.ssl_meta_arch import SSLMetaArch
 
@@ -89,15 +89,14 @@ def build_schedulers(cfg):
         start_warmup_value=cfg.teacher["warmup_teacher_temp"],
     )
 
-    lr_schedule = CosineScheduler(**lr)
-    wd_schedule = CosineScheduler(**wd)
-    momentum_schedule = CosineScheduler(**momentum)
-    teacher_temp_schedule = CosineScheduler(**teacher_temp)
-    last_layer_lr_schedule = CosineScheduler(**lr)
-
-    last_layer_lr_schedule.schedule[
-        : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
-    ] = 0  # mimicking the original schedules
+    lr_schedule = MemEfficientCosineScheduler(**lr)
+    wd_schedule = MemEfficientCosineScheduler(**wd)
+    momentum_schedule = MemEfficientCosineScheduler(**momentum)
+    teacher_temp_schedule = MemEfficientCosineScheduler(**teacher_temp)
+    # this is a hack to mimic the original schedules
+    _lr = lr.copy()
+    _lr.update(freeze_iters=cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH)
+    last_layer_lr_schedule = MemEfficientCosineScheduler(**_lr)
 
     logger.info("Schedulers ready.")
 
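The old code could zero out the frozen prefix only because CosineScheduler precomputes its whole schedule as an array (hence the `last_layer_lr_schedule.schedule[...] = 0` slice assignment above); the lazy scheduler has no array to mutate, so the freeze is passed up front as `freeze_iters`. A minimal sketch of the two spellings, assuming this branch is installed, with hypothetical numbers standing in for the config values:

    from dinov2.utils.utils import CosineScheduler, MemEfficientCosineScheduler

    FREEZE = 1250  # hypothetical: freeze_last_layer_epochs * OFFICIAL_EPOCH_LENGTH
    lr = dict(base_value=0.004, final_value=1e-6, total_iters=125000, warmup_iters=12500)

    # Old spelling: build the full array, then overwrite its frozen prefix.
    old = CosineScheduler(**lr)
    old.schedule[:FREEZE] = 0

    # New spelling: there is no array, so the freeze is a constructor argument.
    _lr = lr.copy()
    _lr.update(freeze_iters=FREEZE)
    new = MemEfficientCosineScheduler(**_lr)

    assert old[0] == 0.0 and new[0] == 0.0  # both frozen at iteration 0

One caveat: if the original builds its warmup ramp from index 0 (as the slice assignment suggests), the two spellings are not bit-identical. The old hack overwrites the start of the warmup ramp in place, while `freeze_iters` delays the warmup until after the frozen prefix and shortens the cosine tail accordingly.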
diff --git a/dinov2/utils/utils.py b/dinov2/utils/utils.py
@@ -87,6 +87,37 @@ class CosineScheduler(object):
         return self.schedule[it]
 
 
+class MemEfficientCosineScheduler:
+    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+        super().__init__()
+        self.final_value = final_value
+        self.total_iters = total_iters
+        self.start_warmup_value = start_warmup_value
+        self.base_value = base_value
+        self.freeze_iters = freeze_iters
+        self.warmup_iters = warmup_iters
+
+    def __getitem__(self, it):
+        if it >= self.total_iters:
+            return self.final_value
+
+        if it < self.freeze_iters:
+            return 0.0
+
+        if it < self.freeze_iters + self.warmup_iters:
+            # Linear warmup - fixed to match original implementation
+            alpha = (it - self.freeze_iters) / max(1, self.warmup_iters)
+            value = self.start_warmup_value * (1 - alpha) + self.base_value * alpha
+            return value
+
+        # Cosine schedule - this part needed adjustment to match CosineScheduler
+        effective_it = it - self.freeze_iters - self.warmup_iters
+        total_cosine_iters = self.total_iters - self.warmup_iters - self.freeze_iters
+        return self.final_value + 0.5 * (self.base_value - self.final_value) * (
+            1 + np.cos(np.pi * effective_it / total_cosine_iters)
+        )
+
+
 def has_batchnorms(model):
     bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
     for name, module in model.named_modules():
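To make the new class's piecewise behavior concrete, a quick sketch (hypothetical values, assuming the class is importable from this branch): a frozen prefix at 0, then a linear warmup from start_warmup_value to base_value, then a cosine decay to final_value.

    from dinov2.utils.utils import MemEfficientCosineScheduler

    sched = MemEfficientCosineScheduler(
        base_value=0.004,
        final_value=1e-6,
        total_iters=1000,
        warmup_iters=100,
        start_warmup_value=0.0,
        freeze_iters=50,
    )

    assert sched[0] == 0.0                 # it < freeze_iters: frozen at 0
    assert sched[50] == 0.0                # first warmup step: alpha = 0 -> start_warmup_value
    assert abs(sched[150] - 0.004) < 1e-9  # warmup finished: cosine starts at base_value
    assert sched[1000] == 1e-6             # it >= total_iters: clamps to final_value

Unlike CosineScheduler, nothing of length total_iters is ever allocated; each value is computed on demand in __getitem__, which is the point of the change for long schedules.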