diff --git a/distillation/config/config_stdc.yaml b/distillation/config/config_stdc.yaml index 5f6da79..75bb24b 100644 --- a/distillation/config/config_stdc.yaml +++ b/distillation/config/config_stdc.yaml @@ -78,4 +78,6 @@ checkpoints: filename: '{epoch}-{val_similarity:.2f}' monitor: val_similarity mode: max - save_top_k: 3 \ No newline at end of file + save_top_k: 3 + +resume_from_checkpoint: /home/arda/dinov2/distillation/checkpoints/stdc/epoch=2-val_similarity=0.33.ckpt \ No newline at end of file diff --git a/distillation/train.py b/distillation/train.py index f9f6c9d..662abc8 100644 --- a/distillation/train.py +++ b/distillation/train.py @@ -106,28 +106,51 @@ class DistillationTrainer: logger = TensorBoardLogger(f"logs/{self.cfg.train.name}", name="distillation") logger.log_hyperparams(self.cfg) - + checkpoint_dir = os.path.join(self.cfg.checkpoints.dirpath, self.cfg.train.name) + checkpoint_path = self._find_latest_checkpoint(checkpoint_dir) + # Also save config as text for better readability experiment_dir = logger.log_dir os.makedirs(experiment_dir, exist_ok=True) config_path = os.path.join(experiment_dir, 'config.yaml') OmegaConf.save(self.cfg, config_path) + if checkpoint_path: + print(f"Resuming training from checkpoint: {checkpoint_path}") + else: + print("Starting training from scratch") + return pl.Trainer( max_epochs=self.training_config.max_epochs, accelerator=self.cfg.train.accelerator, devices=self.cfg.train.devices, num_nodes=self.cfg.train.num_nodes, - strategy=self.cfg.train.strategy, # 'ddp', 'deepspeed', etc. + strategy=self.cfg.train.strategy, precision=self.training_config.precision, callbacks=[checkpoint_callback], logger=logger, + resume_from_checkpoint=checkpoint_path ) def train(self): """Execute training pipeline.""" self.trainer.fit(self.distillation_module, self.data_module) + def _find_latest_checkpoint(self, checkpoint_dir: str) -> str | None: + """Find the latest checkpoint in the specified directory.""" + if not os.path.exists(checkpoint_dir): + return None + + checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')] + if not checkpoints: + return None + + # Sort by modification time, newest first + latest_checkpoint = max( + checkpoints, + key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)) + ) + return os.path.join(checkpoint_dir, latest_checkpoint) def setup_environment(): """Setup environment configurations."""