Added code for continuing training.
parent
26d9ca6d12
commit
69b50c8ff2
|
@ -78,4 +78,6 @@ checkpoints:
|
|||
filename: '{epoch}-{val_similarity:.2f}'
|
||||
monitor: val_similarity
|
||||
mode: max
|
||||
save_top_k: 3
|
||||
save_top_k: 3
|
||||
|
||||
resume_from_checkpoint: /home/arda/dinov2/distillation/checkpoints/stdc/epoch=2-val_similarity=0.33.ckpt
|
|
@ -106,28 +106,51 @@ class DistillationTrainer:
|
|||
|
||||
logger = TensorBoardLogger(f"logs/{self.cfg.train.name}", name="distillation")
|
||||
logger.log_hyperparams(self.cfg)
|
||||
|
||||
checkpoint_dir = os.path.join(self.cfg.checkpoints.dirpath, self.cfg.train.name)
|
||||
checkpoint_path = self._find_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
# Also save config as text for better readability
|
||||
experiment_dir = logger.log_dir
|
||||
os.makedirs(experiment_dir, exist_ok=True)
|
||||
config_path = os.path.join(experiment_dir, 'config.yaml')
|
||||
OmegaConf.save(self.cfg, config_path)
|
||||
|
||||
if checkpoint_path:
|
||||
print(f"Resuming training from checkpoint: {checkpoint_path}")
|
||||
else:
|
||||
print("Starting training from scratch")
|
||||
|
||||
return pl.Trainer(
|
||||
max_epochs=self.training_config.max_epochs,
|
||||
accelerator=self.cfg.train.accelerator,
|
||||
devices=self.cfg.train.devices,
|
||||
num_nodes=self.cfg.train.num_nodes,
|
||||
strategy=self.cfg.train.strategy, # 'ddp', 'deepspeed', etc.
|
||||
strategy=self.cfg.train.strategy,
|
||||
precision=self.training_config.precision,
|
||||
callbacks=[checkpoint_callback],
|
||||
logger=logger,
|
||||
resume_from_checkpoint=checkpoint_path
|
||||
)
|
||||
|
||||
def train(self):
|
||||
"""Execute training pipeline."""
|
||||
self.trainer.fit(self.distillation_module, self.data_module)
|
||||
|
||||
def _find_latest_checkpoint(self, checkpoint_dir: str) -> str | None:
|
||||
"""Find the latest checkpoint in the specified directory."""
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
return None
|
||||
|
||||
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# Sort by modification time, newest first
|
||||
latest_checkpoint = max(
|
||||
checkpoints,
|
||||
key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x))
|
||||
)
|
||||
return os.path.join(checkpoint_dir, latest_checkpoint)
|
||||
|
||||
def setup_environment():
|
||||
"""Setup environment configurations."""
|
||||
|
|
Loading…
Reference in New Issue