diff --git a/conda.yaml b/conda.yaml
index 35dfc30..cc82dac 100644
--- a/conda.yaml
+++ b/conda.yaml
@@ -6,15 +6,17 @@ channels:
   - xformers
   - conda-forge
 dependencies:
-  - python=3.9
-  - pytorch::pytorch=2.0.0
-  - pytorch::pytorch-cuda=11.7.0
-  - pytorch::torchvision=0.15.0
+  - python
+  - pytorch::pytorch
+  - pytorch::pytorch-cuda=12.1
+  - pytorch::torchvision
   - omegaconf
-  - torchmetrics=0.10.3
+  - torchmetrics
   - fvcore
   - iopath
-  - xformers::xformers=0.0.18
+  - xformers::xformers
+  - wandb
+  - h5py
   - pip
   - pip:
     - git+https://github.com/facebookincubator/submitit
diff --git a/dinov2/logging/helpers.py b/dinov2/logging/helpers.py
index c6e70bb..ba28061 100644
--- a/dinov2/logging/helpers.py
+++ b/dinov2/logging/helpers.py
@@ -10,6 +10,7 @@ import logging
 import time
 
 import torch
+import wandb
 
 import dinov2.distributed as distributed
 
@@ -59,6 +60,10 @@ class MetricLogger(object):
             data_time=data_time,
         )
         dict_to_dump.update({k: v.median for k, v in self.meters.items()})
+        # log to wandb as well
+        for k, v in self.meters.items():
+            wandb.log({k: v.median}, step=iteration)
+
         with open(self.output_file, "a") as f:
             f.write(json.dumps(dict_to_dump) + "\n")
         pass
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
index 473b8d0..a20cedd 100644
--- a/dinov2/train/train.py
+++ b/dinov2/train/train.py
@@ -30,6 +30,7 @@ logger = logging.getLogger("dinov2")
 def get_args_parser(add_help: bool = True):
     parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+    parser.add_argument("--no-wandb", action="store_true", help="Whether to disable wandb logging")
     parser.add_argument(
         "--no-resume",
         action="store_true",
diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py
index c9de578..2cdcac6 100644
--- a/dinov2/utils/config.py
+++ b/dinov2/utils/config.py
@@ -7,14 +7,14 @@ import math
 import logging
 import os
 
-from omegaconf import OmegaConf
+import wandb
+from omegaconf import OmegaConf, DictConfig
 
 import dinov2.distributed as distributed
 from dinov2.logging import setup_logging
 from dinov2.utils import utils
 from dinov2.configs import dinov2_default_config
 
-
 logger = logging.getLogger("dinov2")
 
 
@@ -60,6 +60,19 @@ def default_setup(args):
     logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
 
 
+def setup_wandb(cfg: DictConfig):
+    """
+    Set up wandb on the main process
+    """
+    if distributed.is_main_process():
+        config = OmegaConf.to_container(cfg)
+        # disable wandb if no_wandb is set in the config
+        no_wandb = config.get("no_wandb", False)
+        mode = "disabled" if no_wandb else "online"
+        dataset_name = cfg["train"]["dataset_path"].split(":")[0]
+        wandb.init(project=f"dinov2-{dataset_name}", config=config, mode=mode)
+
+
 def setup(args):
     """
     Create configs and perform basic setups.
@@ -69,4 +82,5 @@
     default_setup(args)
     apply_scaling_rules_to_cfg(cfg)
     write_config(cfg, args.output_dir)
+    setup_wandb(cfg)
     return cfg
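
Note on the MetricLogger hunk: wandb accumulates repeated `wandb.log` calls that share an explicit `step` into a single row, so the per-meter loop is correct; the same medians could also be batched into one call. A minimal equivalent sketch, not part of the patch:

    # Equivalent to the loop added above: one wandb.log call per iteration,
    # batching every meter's median into a single dictionary.
    wandb.log({k: v.median for k, v in self.meters.items()}, step=iteration)

Either way the call only runs on rank 0, since dump_in_output_file returns early on non-main processes.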
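
Note on `--no-wandb`: the flag is added to the argparse namespace in train.py, but `setup_wandb` reads `no_wandb` out of the merged OmegaConf config, and none of the hunks above show the flag crossing from one to the other. A minimal sketch of that wiring, assuming it lands in `get_cfg_from_args` in dinov2/utils/config.py (where upstream merges the default config, the config file, and CLI overrides):

    from omegaconf import OmegaConf

    from dinov2.configs import dinov2_default_config


    def get_cfg_from_args(args):
        default_cfg = OmegaConf.create(dinov2_default_config)
        cfg = OmegaConf.load(args.config_file)
        cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
        # Hypothetical addition: copy the argparse flag into the config so that
        # setup_wandb()'s config.get("no_wandb", False) can see it.
        cfg.no_wandb = args.no_wandb
        return cfg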
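
Usage: `wandb.init` derives the project name from the segment of `train.dataset_path` before the first colon, so a dataset path such as `ImageNet:split=TRAIN:root=<...>` produces the project `dinov2-ImageNet`, and passing `--no-wandb` (once wired into the config as sketched above) switches the run to `mode="disabled"`. A hypothetical launch, with an illustrative config path:

    python -m dinov2.train.train --config-file dinov2/configs/train/vitl16_short.yaml --no-wandb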