Added code for continuing training.

pull/493/head
arda 2024-12-01 04:47:56 +00:00
parent 26d9ca6d12
commit 69b50c8ff2
2 changed files with 28 additions and 3 deletions

View File

@ -78,4 +78,6 @@ checkpoints:
filename: '{epoch}-{val_similarity:.2f}'
monitor: val_similarity
mode: max
save_top_k: 3
save_top_k: 3
resume_from_checkpoint: /home/arda/dinov2/distillation/checkpoints/stdc/epoch=2-val_similarity=0.33.ckpt

View File

@ -106,28 +106,51 @@ class DistillationTrainer:
logger = TensorBoardLogger(f"logs/{self.cfg.train.name}", name="distillation")
logger.log_hyperparams(self.cfg)
checkpoint_dir = os.path.join(self.cfg.checkpoints.dirpath, self.cfg.train.name)
checkpoint_path = self._find_latest_checkpoint(checkpoint_dir)
# Also save config as text for better readability
experiment_dir = logger.log_dir
os.makedirs(experiment_dir, exist_ok=True)
config_path = os.path.join(experiment_dir, 'config.yaml')
OmegaConf.save(self.cfg, config_path)
if checkpoint_path:
print(f"Resuming training from checkpoint: {checkpoint_path}")
else:
print("Starting training from scratch")
return pl.Trainer(
max_epochs=self.training_config.max_epochs,
accelerator=self.cfg.train.accelerator,
devices=self.cfg.train.devices,
num_nodes=self.cfg.train.num_nodes,
strategy=self.cfg.train.strategy, # 'ddp', 'deepspeed', etc.
strategy=self.cfg.train.strategy,
precision=self.training_config.precision,
callbacks=[checkpoint_callback],
logger=logger,
resume_from_checkpoint=checkpoint_path
)
def train(self):
"""Execute training pipeline."""
self.trainer.fit(self.distillation_module, self.data_module)
def _find_latest_checkpoint(self, checkpoint_dir: str) -> str | None:
"""Find the latest checkpoint in the specified directory."""
if not os.path.exists(checkpoint_dir):
return None
checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')]
if not checkpoints:
return None
# Sort by modification time, newest first
latest_checkpoint = max(
checkpoints,
key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x))
)
return os.path.join(checkpoint_dir, latest_checkpoint)
def setup_environment():
"""Setup environment configurations."""