From 9164a24a56491e08cc22b49241cc4c58acd7c5a1 Mon Sep 17 00:00:00 2001
From: Adrian Wolny <adrian.wolny@bayer.com>
Date: Wed, 21 Aug 2024 15:54:37 +0200
Subject: [PATCH] setup wandb logging

---
 conda.yaml                | 14 ++++++++------
 dinov2/logging/helpers.py |  5 +++++
 dinov2/train/train.py     |  1 +
 dinov2/utils/config.py    | 18 ++++++++++++++++--
 4 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/conda.yaml b/conda.yaml
index 35dfc30..cc82dac 100644
--- a/conda.yaml
+++ b/conda.yaml
@@ -6,15 +6,17 @@ channels:
   - xformers
   - conda-forge
 dependencies:
-  - python=3.9
-  - pytorch::pytorch=2.0.0
-  - pytorch::pytorch-cuda=11.7.0
-  - pytorch::torchvision=0.15.0
+  - python
+  - pytorch::pytorch
+  - pytorch::pytorch-cuda=12.1
+  - pytorch::torchvision
   - omegaconf
-  - torchmetrics=0.10.3
+  - torchmetrics
   - fvcore
   - iopath
-  - xformers::xformers=0.0.18
+  - xformers::xformers
+  - wandb
+  - h5py
   - pip
   - pip:
     - git+https://github.com/facebookincubator/submitit
diff --git a/dinov2/logging/helpers.py b/dinov2/logging/helpers.py
index c6e70bb..ba28061 100644
--- a/dinov2/logging/helpers.py
+++ b/dinov2/logging/helpers.py
@@ -10,6 +10,7 @@ import logging
 import time
 
 import torch
+import wandb
 
 import dinov2.distributed as distributed
 
@@ -59,6 +60,10 @@ class MetricLogger(object):
             data_time=data_time,
         )
         dict_to_dump.update({k: v.median for k, v in self.meters.items()})
+        # log to wandb as well
+        for k, v in self.meters.items():
+            wandb.log({k: v.median}, step=iteration)
+
         with open(self.output_file, "a") as f:
             f.write(json.dumps(dict_to_dump) + "\n")
         pass
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
index 473b8d0..a20cedd 100644
--- a/dinov2/train/train.py
+++ b/dinov2/train/train.py
@@ -30,6 +30,7 @@ logger = logging.getLogger("dinov2")
 def get_args_parser(add_help: bool = True):
     parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+    parser.add_argument("--no-wandb", action="store_true", help="Disable wandb logging")
     parser.add_argument(
         "--no-resume",
         action="store_true",
diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py
index c9de578..2cdcac6 100644
--- a/dinov2/utils/config.py
+++ b/dinov2/utils/config.py
@@ -7,14 +7,14 @@ import math
 import logging
 import os
 
-from omegaconf import OmegaConf
+import wandb
+from omegaconf import OmegaConf, DictConfig
 
 import dinov2.distributed as distributed
 from dinov2.logging import setup_logging
 from dinov2.utils import utils
 from dinov2.configs import dinov2_default_config
 
-
 logger = logging.getLogger("dinov2")
 
 
@@ -60,6 +60,19 @@ def default_setup(args):
     logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
 
 
+def setup_wandb(cfg: DictConfig):
+    """
+    Set up wandb logging (runs in the main process only)
+    """
+    if distributed.is_main_process():
+        config = OmegaConf.to_container(cfg)
+        # disable wandb if no_wandb is set in the config; NOTE(review): the --no-wandb CLI flag is parsed into args, not cfg — confirm it is merged into the config before this lookup
+        no_wandb = config.get("no_wandb", False)
+        mode = "disabled" if no_wandb else "online"
+        dataset_name = cfg["train"]["dataset_path"].split(":")[0]
+        wandb.init(project=f"dinov2-{dataset_name}", config=config, mode=mode)
+
+
 def setup(args):
     """
     Create configs and perform basic setups.
@@ -69,4 +82,5 @@ def setup(args):
     default_setup(args)
     apply_scaling_rules_to_cfg(cfg)
     write_config(cfg, args.output_dir)
+    setup_wandb(cfg)
     return cfg