diff --git a/dinov2/configs/distill/vitl14_vitb14.yaml b/dinov2/configs/distill/vitl14_vitb14.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/dinov2/configs/distill_default_config.yaml b/dinov2/configs/distill_default_config.yaml deleted file mode 100644 index 9f3bf0f..0000000 --- a/dinov2/configs/distill_default_config.yaml +++ /dev/null @@ -1,126 +0,0 @@ -MODEL: - WEIGHTS: '' -compute_precision: - grad_scaler: true - teacher: - backbone: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - dino_head: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - ibot_head: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - student: - backbone: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - dino_head: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp32 - buffer_dtype: fp32 - ibot_head: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp32 - buffer_dtype: fp32 -dino: - loss_weight: 1.0 - head_n_prototypes: 65536 - head_bottleneck_dim: 256 - head_nlayers: 3 - head_hidden_dim: 2048 - koleo_loss_weight: 0.1 -ibot: - loss_weight: 1.0 - mask_sample_probability: 0.0 - mask_ratio_min_max: - - 0.1 - - 0.5 - separate_head: false - head_n_prototypes: 65536 - head_bottleneck_dim: 256 - head_nlayers: 3 - head_hidden_dim: 2048 -train: - batch_size_per_gpu: 64 - dataset_path: ImageNet:split=TRAIN - output_dir: . - saveckp_freq: 20 - seed: 0 - num_workers: 10 - OFFICIAL_EPOCH_LENGTH: 1250 - cache_dataset: true - centering: "centering" # or "sinkhorn_knopp" -student: - arch: vit_large - patch_size: 16 - drop_path_rate: 0.0 - layerscale: 1.0e-05 - drop_path_uniform: true - pretrained_weights: '' - ffn_layer: "mlp" - block_chunks: 0 - qkv_bias: true - proj_bias: true - ffn_bias: true -teacher: - arch: vit_base - patch_size: 16 - drop_path_rate: 0.0 - layerscale: 1.0e-05 - drop_path_uniform: true - pretrained_weights: '' - ffn_layer: "mlp" - block_chunks: 0 - qkv_bias: true - proj_bias: true - ffn_bias: true - momentum_teacher: 0.992 - final_momentum_teacher: 1 - warmup_teacher_temp: 0.04 - teacher_temp: 0.07 - warmup_teacher_temp_epochs: 30 -optim: - epochs: 100 - weight_decay: 0.04 - weight_decay_end: 0.4 - base_lr: 0.004 # learning rate for a batch size of 1024 - lr: 0. 
# will be set after applying scaling rule - warmup_epochs: 10 - min_lr: 1.0e-06 - clip_grad: 3.0 - freeze_last_layer_epochs: 1 - scaling_rule: sqrt_wrt_1024 - patch_embed_lr_mult: 0.2 - layerwise_decay: 0.9 - adamw_beta1: 0.9 - adamw_beta2: 0.999 -crops: - global_crops_scale: - - 0.32 - - 1.0 - local_crops_number: 8 - local_crops_scale: - - 0.05 - - 0.32 - global_crops_size: 224 - local_crops_size: 96 -evaluation: - eval_period_iterations: 12500 \ No newline at end of file diff --git a/dinov2/configs/train/nextvit.yaml b/dinov2/configs/train/nextvit.yaml new file mode 100644 index 0000000..167d9b9 --- /dev/null +++ b/dinov2/configs/train/nextvit.yaml @@ -0,0 +1,7 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN + batch_size_per_gpu: 64 +student: + arch: vit_large # vit_large, nextvit + block_chunks: 4 diff --git a/dinov2/distributed/__init__.py b/dinov2/distributed/__init__.py index 23226f4..a9addef 100644 --- a/dinov2/distributed/__init__.py +++ b/dinov2/distributed/__init__.py @@ -12,8 +12,8 @@ from typing import Dict, List import torch import torch.distributed as dist -_LOCAL_RANK = -1 -_LOCAL_WORLD_SIZE = -1 +_LOCAL_RANK = 0 +_LOCAL_WORLD_SIZE = 1 def is_enabled() -> bool: diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py index 3fdff20..a16ecce 100644 --- a/dinov2/models/__init__.py +++ b/dinov2/models/__init__.py @@ -6,6 +6,7 @@ import logging from . import vision_transformer as vits +from .nextvit import NextVitSmall logger = logging.getLogger("dinov2") @@ -27,14 +28,20 @@ def build_model(args, only_teacher=False, img_size=224): interpolate_offset=args.interpolate_offset, interpolate_antialias=args.interpolate_antialias, ) - teacher = vits.__dict__[args.arch](**vit_kwargs) + if args.arch in vits.__dict__: + teacher = vits.__dict__[args.arch](**vit_kwargs) + else: + teacher = NextVitSmall() if only_teacher: return teacher, teacher.embed_dim - student = vits.__dict__[args.arch]( - **vit_kwargs, - drop_path_rate=args.drop_path_rate, - drop_path_uniform=args.drop_path_uniform, - ) + if args.arch in vits.__dict__: + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + else: + student = NextVitSmall() embed_dim = student.embed_dim return student, teacher, embed_dim diff --git a/dinov2/models/nextvit.py b/dinov2/models/nextvit.py new file mode 100644 index 0000000..8c1e00c --- /dev/null +++ b/dinov2/models/nextvit.py @@ -0,0 +1,50 @@ +import torch +from torch.nn import functional + +import sys +sys.path.append('/home/li.yu/code/JupiterCVML/europa/base/src/europa') +from dl.network.nextvit_brt import _get_nextvit + +class NextVitSmall(torch.nn.Module): + """BRT Segmentation model with definition to make it a custom model supported.""" + def __init__(self, num_classes=197*1024) -> None: + super().__init__() + + # define backbone + self.backbone = _get_nextvit( + model_size="small", + frozen_stages=-1, + norm_eval=False, + with_extra_norm=True, + norm_cfg=dict(type="SyncBN", requires_grad=True), + in_channels=3, + ) + + # self.proj_head = torch.nn.Sequential( + # torch.nn.Linear(1024, num_classes), + # ) + assert num_classes == 197 * 1024 + self.num_register_tokens = 1 + self.embed_dim = 1024 + self.proj_head = torch.nn.Linear(1024, num_classes) + + def forward_backbone(self, x, masks=None): + y = self.backbone(x) + y = functional.adaptive_avg_pool2d(y[-1], (1, 1)) + y = torch.flatten(y, 1) + y = self.proj_head(y) + + n = y.shape[0] + y_reshaped = 
y.reshape(n, 197, 1024) + return { + "x_norm_clstoken": y_reshaped[:, 0], # teacher 128x1024 + "x_norm_regtokens": y_reshaped[:, 1 : self.num_register_tokens + 1], # teacher 128x0x1024 + "x_norm_patchtokens": y_reshaped[:, self.num_register_tokens + 1 :], # teacher 128x196x1024 + "x_prenorm": None, + "masks": masks, + } + + def forward(self, x, *args, masks=None, **kwargs): + if isinstance(x, list): + return [self.forward_backbone(_x, _masks) for _x, _masks in zip(x, masks)] + return self.forward_backbone(x, masks) \ No newline at end of file diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index 13b44ae..9ca75c9 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -212,14 +212,14 @@ class DinoVisionTransformer(nn.Module): def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape - x = self.patch_embed(x) + x = self.patch_embed(x) # teacher 128x3x224x224 -> 128x196x1024 if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) # teacher 128x197x1024 + x = x + self.interpolate_pos_encoding(x, w, h) # teacher 128x197x1024 - if self.register_tokens is not None: + if self.register_tokens is not None: # self.register_tokens is None x = torch.cat( ( x[:, :1], @@ -232,9 +232,9 @@ class DinoVisionTransformer(nn.Module): return x def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] # student 128x3x224x224 -> 128x197x1024, 512x3x96x96 -> 512x37x1024 for blk in self.blocks: - x = blk(x) + x = blk(x) # student 128x197x1024, 512x37x1024 all_x = x output = [] @@ -242,30 +242,30 @@ class DinoVisionTransformer(nn.Module): x_norm = self.norm(x) output.append( { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, + "x_norm_clstoken": x_norm[:, 0], # student 128x1024, 512x1024 + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], # student 128x0x1024, 512x0x1024 + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], # student 128x196x1024, 512x36x1024 + "x_prenorm": x, # student 128x197x1024, 512x37x1024 + "masks": masks, # student 128x196, None } ) return output def forward_features(self, x, masks=None): - if isinstance(x, list): + if isinstance(x, list): # student x [128x3x224x224, 512x3x96x96], masks [128x196, None] return self.forward_features_list(x, masks) - x = self.prepare_tokens_with_masks(x, masks) + x = self.prepare_tokens_with_masks(x, masks) # teacher 128x3x224x224 -> 128x197x1024 for blk in self.blocks: - x = blk(x) + x = blk(x) # teacher 128x197x1024 - x_norm = self.norm(x) + x_norm = self.norm(x) # teacher 128x197x1024 return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, + "x_norm_clstoken": x_norm[:, 0], # teacher 128x1024 + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], # teacher 128x0x1024 + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], # 
teacher 128x196x1024 + "x_prenorm": x, # teacher 128x197x1024 "masks": masks, } diff --git a/dinov2/run/submit.py b/dinov2/run/submit.py index 313be9d..4434ac7 100644 --- a/dinov2/run/submit.py +++ b/dinov2/run/submit.py @@ -9,7 +9,7 @@ import os from pathlib import Path from typing import List, Optional -import submitit +# import submitit from dinov2.utils.cluster import ( get_slurm_executor_parameters, @@ -61,16 +61,10 @@ def get_args_parser( help="Partition where to submit", ) parser.add_argument( - "--mem-per-gpu", - default="60G", - type=str, - help="Memory per GPU", + "--use-volta32", + action="store_true", + help="Request V100-32GB GPUs", ) - # parser.add_argument( - # "--use-volta32", - # action="store_true", - # help="Request V100-32GB GPUs", - # ) parser.add_argument( "--comment", default="", @@ -103,8 +97,8 @@ def submit_jobs(task_class, args, name: str): executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) kwargs = {} - # if args.use_volta32: - # kwargs["slurm_constraint"] = "volta32gb" + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" if args.comment: kwargs["slurm_comment"] = args.comment if args.exclude: @@ -116,11 +110,10 @@ def submit_jobs(task_class, args, name: str): timeout_min=args.timeout, # max is 60 * 72 slurm_signal_delay_s=120, slurm_partition=args.partition, - mem_per_gpu=args.mem_per_gpu, **kwargs, ) executor.update_parameters(name=name, **executor_params) - print(args, executor_params) + task = task_class(args) job = executor.submit(task) diff --git a/dinov2/run/train/train.py b/dinov2/run/train/train.py index c2366e9..332b921 100644 --- a/dinov2/run/train/train.py +++ b/dinov2/run/train/train.py @@ -9,7 +9,8 @@ import sys from dinov2.logging import setup_logging from dinov2.train import get_args_parser as get_train_args_parser -from dinov2.run.submit import get_args_parser, submit_jobs +# from dinov2.run.submit import get_args_parser, submit_jobs +from dinov2.run.submit import get_args_parser logger = logging.getLogger("dinov2") @@ -22,7 +23,7 @@ class Trainer(object): def __call__(self): from dinov2.train import main as train_main - self._setup_args() + # self._setup_args() train_main(self.args) def checkpoint(self): @@ -51,7 +52,13 @@ def main(): setup_logging() assert os.path.exists(args.config_file), "Configuration file does not exist!" - submit_jobs(Trainer, args, name="dinov2:train") + print(args) + # submit_jobs(Trainer, args, name="dinov2:train") + + from dinov2.train import main as train_main + + train_main(args) + return 0 diff --git a/dinov2/train/distill_meta_arch.py b/dinov2/train/distill_meta_arch.py deleted file mode 100644 index 7920ac9..0000000 --- a/dinov2/train/distill_meta_arch.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
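# Illustrative sketch (not part of the diff): the output-dict contract that both
# DinoVisionTransformer.forward_features and the new NextVitSmall.forward_backbone follow,
# with the shapes from the annotations above (batch 128, 224x224 global crops, patch 16 ->
# 196 patch tokens, embed_dim 1024). num_register_tokens is 0 for the ViT here and 1 for
# NextVitSmall; the helper name and the random input are illustrative only.
import torch

def split_tokens(x_norm: torch.Tensor, num_register_tokens: int, masks=None):
    # x_norm: [B, 1 + num_register_tokens + n_patches, embed_dim]
    return {
        "x_norm_clstoken": x_norm[:, 0],                             # [B, 1024]
        "x_norm_regtokens": x_norm[:, 1 : num_register_tokens + 1],  # [B, R, 1024]
        "x_norm_patchtokens": x_norm[:, num_register_tokens + 1 :],  # [B, 196, 1024]
        "x_prenorm": x_norm,  # the real models return the pre-norm tensor here
        "masks": masks,
    }

out = split_tokens(torch.randn(128, 197, 1024), num_register_tokens=0)
assert out["x_norm_clstoken"].shape == (128, 1024)
assert out["x_norm_patchtokens"].shape == (128, 196, 1024)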
- -from functools import partial -import logging -import copy - -import torch -from torch import nn - -from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss -from dinov2.models import build_model -from dinov2.layers import DINOHead -from dinov2.utils.utils import has_batchnorms -from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups -from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model - -from dinov2.models.vision_transformer import BlockChunk - -try: - from xformers.ops import fmha - - XFORMERS_AVAILABLE = True -except ImportError: - XFORMERS_AVAILABLE = False -assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training" - - -logger = logging.getLogger("dinov2") - - -class DistillMetaArch(nn.Module): - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None - - student_model_dict = dict() - teacher_model_dict = dict() - - student_backbone, student_embed_dim = build_model(cfg.student, only_teacher=True, img_size=cfg.crops.global_crops_size) # pyright: ignore - teacher_backbone, teacher_embed_dim = build_model(cfg.teacher, only_teacher=True, img_size=cfg.crops.global_crops_size) # pyright: ignore - student_model_dict["backbone"] = student_backbone - teacher_model_dict["backbone"] = teacher_backbone - - # Student and teacher embedding dimensions can be different and therefore DINO head and IBOT heads should be different - logger.info(f"OPTIONS -- architecture : student_embed_dim: {student_embed_dim}") - logger.info(f"OPTIONS -- architecture : teacher_embed_dim: {teacher_embed_dim}") - - self.student_embed_dim = student_embed_dim - self.teacher_embed_dim = teacher_embed_dim - self.dino_out_dim = cfg.dino.head_n_prototypes - - self.do_dino = cfg.dino.loss_weight > 0 - self.do_koleo = cfg.dino.koleo_loss_weight > 0 - self.do_ibot = cfg.ibot.loss_weight > 0 - self.ibot_separate_head = cfg.ibot.separate_head - - logger.info("OPTIONS -- DINO") - if self.do_dino: - logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") - logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") - logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") - logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") - self.dino_loss_weight = cfg.dino.loss_weight - dino_head = partial( - DINOHead, - out_dim=cfg.dino.head_n_prototypes, - hidden_dim=cfg.dino.head_hidden_dim, - bottleneck_dim=cfg.dino.head_bottleneck_dim, - nlayers=cfg.dino.head_nlayers, - ) - - self.dino_loss = DINOLoss(self.dino_out_dim) - if self.do_koleo: - logger.info("OPTIONS -- DINO -- applying KOLEO regularization") - self.koleo_loss = KoLeoLoss() - - else: - logger.info("OPTIONS -- DINO -- not using DINO") - - if self.do_dino or self.do_ibot: - student_model_dict["dino_head"] = dino_head(in_dim=student_embed_dim) # pyright: ignore - teacher_model_dict["dino_head"] = dino_head(in_dim=teacher_embed_dim) # pyright: ignore - - logger.info("OPTIONS -- IBOT") - logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") - logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") - logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") - if self.do_ibot: - self.ibot_loss_weight = cfg.ibot.loss_weight - assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask 
ratio tuple for ibot" - assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" - self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes - self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) - if self.ibot_separate_head: - logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") - logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") - logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") - logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") - ibot_head = partial( - DINOHead, - out_dim=cfg.ibot.head_n_prototypes, - hidden_dim=cfg.ibot.head_hidden_dim, - bottleneck_dim=cfg.ibot.head_bottleneck_dim, - nlayers=cfg.ibot.head_nlayers, - ) - student_model_dict["ibot_head"] = ibot_head(in_dim=student_embed_dim) - teacher_model_dict["ibot_head"] = ibot_head(in_dim=teacher_embed_dim) - else: - logger.info("OPTIONS -- IBOT -- head shared with DINO") - - # self.need_to_synchronize_fsdp_streams = True - - self.student = nn.ModuleDict(student_model_dict) - self.teacher = nn.ModuleDict(teacher_model_dict) - self.student_shadow = copy.deepcopy(self.student) - - assert cfg.teacher.pretrained_weights is not None, "Must contain pretrained weights for distillation." - chkpt = torch.load(cfg.teacher.pretrained_weights) - logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.teacher.pretrained_weights}") - self.teacher.load_state_dict(chkpt["model"], strict=False) - - # there is no backpropagation through the teacher, so no need for gradients - for p in self.teacher.parameters(): - p.requires_grad = False - - # there is no backpropagation through the shadow copy either - for p in self.student_shadow.parameters(): - p.requires_grad = False - - logger.info(f"Student is built: it is using {cfg.student.arch}") - logger.info(f"Teacher is built: it is using {cfg.teacher.arch}") - - def forward(self, inputs): - raise NotImplementedError - - def backprop_loss(self, loss): - if self.fp16_scaler is not None: - self.fp16_scaler.scale(loss).backward() # pyright: ignore - else: - loss.backward() - - def forward_backward(self, images, teacher_temp): - n_global_crops = 2 - assert n_global_crops == 2 - n_local_crops = self.cfg.crops.local_crops_number - - global_crops = images["collated_global_crops"].cuda(non_blocking=True) - local_crops = images["collated_local_crops"].cuda(non_blocking=True) - - masks = images["collated_masks"].cuda(non_blocking=True) - mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) - n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) - n_masked_patches = mask_indices_list.shape[0] - upperbound = images["upperbound"] - masks_weight = images["masks_weight"].cuda(non_blocking=True) - - n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) - n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops - - do_dino = self.do_dino - do_ibot = self.do_ibot - - # loss scales - ibot_loss_scale = 1.0 / n_global_crops - - # teacher output - @torch.no_grad() - def get_teacher_output(): - x, n_global_crops_teacher = global_crops, n_global_crops - teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) # pyright: ignore - teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] - teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) - # watch out: these are chunked and cat'd in reverse so A is matched 
to B in the global crops dino loss - teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) - ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] - _dim = ibot_teacher_patch_tokens.shape[-1] - n_cls_tokens = teacher_cls_tokens.shape[0] - - if do_ibot and not self.ibot_separate_head: - buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) - buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) - torch.index_select( - ibot_teacher_patch_tokens.flatten(0, 1), - dim=0, - index=mask_indices_list, - out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], - ) - tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) - teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] - masked_teacher_patch_tokens_after_head = tokens_after_head[ - n_cls_tokens : n_cls_tokens + n_masked_patches - ] - elif do_ibot and self.ibot_separate_head: - buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) - torch.index_select( - ibot_teacher_patch_tokens.flatten(0, 1), - dim=0, - index=mask_indices_list, - out=buffer_tensor_teacher[:n_masked_patches], - ) - teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) - masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ - :n_masked_patches - ] - else: - teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) - masked_teacher_ibot_softmaxed_centered = None - - if self.cfg.train.centering == "centering": - teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( - teacher_cls_tokens_after_head, teacher_temp=teacher_temp - ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) - self.dino_loss.update_center(teacher_cls_tokens_after_head) - if do_ibot: - masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) - masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( - masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp - ) - masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) - self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) - - elif self.cfg.train.centering == "sinkhorn_knopp": - teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( - teacher_cls_tokens_after_head, teacher_temp=teacher_temp - ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) - - if do_ibot: - masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( - masked_teacher_patch_tokens_after_head, - teacher_temp=teacher_temp, - n_masked_patches_tensor=n_masked_patches_tensor, - ) - - else: - raise NotImplementedError - - return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered - - teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() - reshard_fsdp_model(self.teacher) - - loss_dict = {} - - loss_accumulator = 0 # for backprop - student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( - [global_crops, local_crops], masks=[masks, None], is_training=True - ) - - inputs_for_student_head_list = [] - - # 1a: local crops cls tokens - student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] - 
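# Toy sketch (not part of the diff) of the "chunked and cat'd in reverse" step above.
# The collated global crops are laid out as [crop A of every image, crop B of every image];
# swapping the two teacher halves means each student A crop is matched against the teacher's
# B crop of the same image, which is the cross-view DINO objective. Sizes are toy values;
# the annotated run uses 2 global crops, 64 images per GPU, dim 1024.
import torch

n_global_crops, images_per_gpu, dim = 2, 4, 8
teacher_cls = torch.randn(n_global_crops * images_per_gpu, dim)   # rows [A0..A3, B0..B3]
chunks = teacher_cls.chunk(n_global_crops)                        # (A crops, B crops)
swapped = torch.cat((chunks[1], chunks[0]))                       # rows [B0..B3, A0..A3]
assert torch.equal(swapped[:images_per_gpu], chunks[1])           # A positions now hold B targets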
inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) - - # 1b: global crops cls tokens - student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] - inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) - - # 1c: global crops patch tokens - if do_ibot: - _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] - ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] - buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) - buffer_tensor_patch_tokens[:n_masked_patches].copy_( - torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) - ) - if not self.ibot_separate_head: - inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) - else: - student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ - :n_masked_patches - ] - - # 2: run - _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) - outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) - - # 3a: local crops cls tokens - student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) - - # 3b: global crops cls tokens - student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) - - # 3c: global crops patch tokens - if do_ibot and not self.ibot_separate_head: - student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] - - if n_local_crops > 0: - dino_local_crops_loss = self.dino_loss( - student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), - teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, - ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) - - # store for display - loss_dict["dino_local_crops_loss"] = dino_local_crops_loss - - # accumulate loss - loss_accumulator += self.dino_loss_weight * dino_local_crops_loss - - # process global crops - loss_scales = 2 # this is here since we process global crops together - - if do_dino: - # compute loss - dino_global_crops_loss = ( - self.dino_loss( - student_output_list=[student_global_cls_tokens_after_head], - teacher_out_softmaxed_centered_list=[ - teacher_dino_softmaxed_centered_list.flatten(0, 1) - ], # these were chunked and stacked in reverse so A is matched to B - ) - * loss_scales - / (n_global_crops_loss_terms + n_local_crops_loss_terms) - ) - - loss_dict["dino_global_crops_loss"] = dino_global_crops_loss - - # accumulate loss - loss_accumulator += self.dino_loss_weight * dino_global_crops_loss - - student_cls_tokens = student_global_cls_tokens - - if self.do_koleo: - koleo_loss = self.cfg.dino.koleo_loss_weight * sum( - self.koleo_loss(p) for p in student_cls_tokens.chunk(2) - ) # we don't apply koleo loss between cls tokens of a same image - loss_accumulator += koleo_loss - loss_dict["koleo_loss"] = ( - koleo_loss / loss_scales - ) # this is to display the same losses as before but we can remove eventually - - if do_ibot: - # compute loss - ibot_patch_loss = ( - self.ibot_patch_loss.forward_masked( - student_global_masked_patch_tokens_after_head, - masked_teacher_ibot_softmaxed_centered, - student_masks_flat=masks, - n_masked_patches=n_masked_patches, - masks_weight=masks_weight, - ) - * loss_scales - * ibot_loss_scale - ) - - # store for display - loss_dict["ibot_loss"] = ibot_patch_loss / 2 - - # accumulate loss - loss_accumulator += self.ibot_loss_weight * ibot_patch_loss - - 
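# Arithmetic sketch (not part of the diff) of the crop-pair bookkeeping used right above,
# with the default crop settings from the config (2 global crops, 8 local crops).
n_global_crops, n_local_crops = 2, 8
n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)   # 16 local->global pairs
n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops   # 2 global->global pairs
total_terms = n_local_crops_loss_terms + n_global_crops_loss_terms  # 18
loss_scales = 2   # both global crops go through the head in one call, hence the extra factor
# dino_local_crops_loss is divided by total_terms; dino_global_crops_loss is multiplied by
# loss_scales and divided by total_terms; the iBOT term is further scaled by
# 1 / n_global_crops, so each crop pair contributes a comparable weight to the total loss.
print(total_terms, loss_scales / total_terms)   # 18 0.111...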
self.backprop_loss(loss_accumulator) - - # self.fsdp_synchronize_streams() - - return loss_dict - - # def fsdp_synchronize_streams(self): - # if self.need_to_synchronize_fsdp_streams: - # torch.cuda.synchronize() - # self.student.dino_head._streams = ( - # self.teacher.dino_head._streams - # ) = self.student.backbone._streams = self.teacher.backbone._streams - # self.need_to_synchronize_fsdp_streams = False - - # def update_teacher(self, m): - # student_param_list = [] - # teacher_param_list = [] - # with torch.no_grad(): - # for k in self.student.keys(): - # for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): - # student_param_list += ms.params - # teacher_param_list += mt.params - # torch._foreach_mul_(teacher_param_list, m) - # torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) - - def update_student_shadow(self, m): - student_param_list = [] - student_shadow_param_list = [] - with torch.no_grad(): - for k in self.student.keys(): - for ms, mss in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.student_shadow[k])): - student_param_list += ms.params - student_shadow_param_list += mss.params - torch._foreach_mul_(student_shadow_param_list, m) - torch._foreach_add_(student_shadow_param_list, student_param_list, alpha=1 - m) - - def train(self): - super().train() - self.teacher.eval() - self.student_shadow.eval() - - def get_maybe_fused_params_for_submodel(self, m): - params_groups = get_params_groups_with_decay( - model=m, - lr_decay_rate=self.cfg.optim.layerwise_decay, - patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, - ) - fused_params_groups = fuse_params_groups(params_groups) - logger.info("fusing param groups") - - for g in fused_params_groups: - g["foreach"] = True - return fused_params_groups - - def get_params_groups(self): - all_params_groups = [] - for m in self.student.values(): - all_params_groups += self.get_maybe_fused_params_for_submodel(m) - return all_params_groups - - def prepare_for_distributed_training(self): - logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") - if has_batchnorms(self.student): - raise NotImplementedError - - # below will synchronize all student subnetworks across gpus: - for k, v in self.student.items(): - self.student_shadow[k].load_state_dict(self.student[k].state_dict()) - # self.teacher[k].load_state_dict(self.student[k].state_dict()) - student_model_cfg = self.cfg.compute_precision.student[k] - self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) - - for k, v in self.teacher.items(): - teacher_model_cfg = self.cfg.compute_precision.teacher[k] - self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) \ No newline at end of file diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py index 1568b54..d991be4 100644 --- a/dinov2/train/ssl_meta_arch.py +++ b/dinov2/train/ssl_meta_arch.py @@ -37,7 +37,7 @@ class SSLMetaArch(nn.Module): student_model_dict = dict() teacher_model_dict = dict() - student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) # embed_dim=1024 student_model_dict["backbone"] = student_backbone teacher_model_dict["backbone"] = teacher_backbone logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") @@ -47,13 +47,13 @@ class SSLMetaArch(nn.Module): logger.info(f"OPTIONS -- pretrained weights: loading from 
{cfg.student.pretrained_weights}") student_backbone.load_state_dict(chkpt["model"], strict=False) - self.embed_dim = embed_dim - self.dino_out_dim = cfg.dino.head_n_prototypes + self.embed_dim = embed_dim # 1024 + self.dino_out_dim = cfg.dino.head_n_prototypes # 65536 - self.do_dino = cfg.dino.loss_weight > 0 - self.do_koleo = cfg.dino.koleo_loss_weight > 0 - self.do_ibot = cfg.ibot.loss_weight > 0 - self.ibot_separate_head = cfg.ibot.separate_head + self.do_dino = cfg.dino.loss_weight > 0 # True + self.do_koleo = cfg.dino.koleo_loss_weight > 0 # True + self.do_ibot = cfg.ibot.loss_weight > 0 # True + self.ibot_separate_head = cfg.ibot.separate_head # False logger.info("OPTIONS -- DINO") if self.do_dino: @@ -81,7 +81,7 @@ class SSLMetaArch(nn.Module): if self.do_dino or self.do_ibot: student_model_dict["dino_head"] = dino_head() teacher_model_dict["dino_head"] = dino_head() - + logger.info("OPTIONS -- IBOT") logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") @@ -134,18 +134,18 @@ class SSLMetaArch(nn.Module): assert n_global_crops == 2 n_local_crops = self.cfg.crops.local_crops_number - global_crops = images["collated_global_crops"].cuda(non_blocking=True) - local_crops = images["collated_local_crops"].cuda(non_blocking=True) + global_crops = images["collated_global_crops"].cuda(non_blocking=True) # 128x3x224x224 + local_crops = images["collated_local_crops"].cuda(non_blocking=True) # 512x3x96x96 - masks = images["collated_masks"].cuda(non_blocking=True) - mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) - n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) - n_masked_patches = mask_indices_list.shape[0] - upperbound = images["upperbound"] - masks_weight = images["masks_weight"].cuda(non_blocking=True) + masks = images["collated_masks"].cuda(non_blocking=True) # 128x196 + mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) # 3730x + n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) # 1x + n_masked_patches = mask_indices_list.shape[0] # 3730 + upperbound = images["upperbound"] # 3771 + masks_weight = images["masks_weight"].cuda(non_blocking=True) # 3730x - n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) - n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) # 16 + n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops # 2 do_dino = self.do_dino do_ibot = self.do_ibot @@ -226,7 +226,7 @@ class SSLMetaArch(nn.Module): return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered - teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() + teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() # [64x65536, 64x65536], 3730x65536 reshard_fsdp_model(self.teacher) loss_dict = {} diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 473b8d0..fa666c2 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -22,6 +22,10 @@ from dinov2.utils.utils import CosineScheduler from dinov2.train.ssl_meta_arch import SSLMetaArch +import sys +sys.path.append('/home/li.yu/code/JupiterCVML/europa/base/src/europa') +from dl.network.nextvit_brt import _get_nextvit + torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default logger = 
logging.getLogger("dinov2") @@ -54,6 +58,9 @@ For python-based LazyConfig, use "path.key=value". type=str, help="Output directory to save logs and checkpoints", ) + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') return parser @@ -133,7 +140,7 @@ def do_test(cfg, model, iteration): def do_train(cfg, model, resume=False): model.train() - inputs_dtype = torch.half + inputs_dtype = torch.half # torch.float16 fp16_scaler = model.fp16_scaler # for mixed precision training # setup optimizer @@ -152,8 +159,8 @@ def do_train(cfg, model, resume=False): start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 - OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH - max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH # 1250 + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH # 125000 periodic_checkpointer = PeriodicCheckpointer( checkpointer, @@ -164,9 +171,9 @@ def do_train(cfg, model, resume=False): # setup data preprocessing - img_size = cfg.crops.global_crops_size - patch_size = cfg.student.patch_size - n_tokens = (img_size // patch_size) ** 2 + img_size = cfg.crops.global_crops_size # 224 + patch_size = cfg.student.patch_size # 16 + n_tokens = (img_size // patch_size) ** 2 # 196 mask_generator = MaskingGenerator( input_size=(img_size // patch_size, img_size // patch_size), max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, @@ -300,6 +307,18 @@ def main(args): model = SSLMetaArch(cfg).to(torch.device("cuda")) model.prepare_for_distributed_training() + # model = _get_nextvit( + # model_size="small", + # frozen_stages=-1, + # norm_eval=False, + # with_extra_norm=True, + # norm_cfg=dict(type="SyncBN", requires_grad=True), + # in_channels=3, + # ) + # print('tunable parameters', sum(p.numel() for p in model.parameters() if p.requires_grad)) + # if args.distributed: + # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + logger.info("Model:\n{}".format(model)) if args.eval_only: iteration = ( diff --git a/dinov2/train/train_distill.py b/dinov2/train/train_distill.py deleted file mode 100644 index e6f58a9..0000000 --- a/dinov2/train/train_distill.py +++ /dev/null @@ -1,319 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
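# Back-of-the-envelope sketch (not part of the diff) for the masking shapes annotated in
# ssl_meta_arch.py above. Only img_size/patch_size come from this config; the iBOT settings
# used below (mask_sample_probability 0.5, mask_ratio_min_max (0.1, 0.5)) are the stock SSL
# defaults and are an assumption here.
img_size, patch_size = 224, 16
n_tokens = (img_size // patch_size) ** 2               # 196 patch tokens per global crop
n_global_crops_per_gpu = 2 * 64                        # 2 global crops x batch_size_per_gpu
expected_masked = n_global_crops_per_gpu * 0.5 * ((0.1 + 0.5) / 2) * n_tokens
print(n_tokens, round(expected_masked))                # 196 ~3763, close to the logged 3730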
- -import argparse -import logging -import math -import os -from functools import partial - -from fvcore.common.checkpoint import PeriodicCheckpointer -import torch - -from dinov2.data import SamplerType, make_data_loader, make_dataset -from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator -import dinov2.distributed as distributed -from dinov2.fsdp import FSDPCheckpointer -from dinov2.logging import MetricLogger -from dinov2.utils.config import setup -from dinov2.utils.utils import CosineScheduler - -from dinov2.train.distill_meta_arch import DistillMetaArch - - -torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default -logger = logging.getLogger("dinov2") - - -def get_args_parser(add_help: bool = True): - parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) - parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") - parser.add_argument( - "--no-resume", - action="store_true", - help="Whether to not attempt to resume from the checkpoint directory. ", - ) - parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") - parser.add_argument("--eval", type=str, default="", help="Eval type to perform") - parser.add_argument( - "opts", - help=""" -Modify config options at the end of the command. For Yacs configs, use -space-separated "PATH.KEY VALUE" pairs. -For python-based LazyConfig, use "path.key=value". - """.strip(), - default=None, - nargs=argparse.REMAINDER, - ) - parser.add_argument( - "--output-dir", - "--output_dir", - default="", - type=str, - help="Output directory to save logs and checkpoints", - ) - - return parser - - -def build_optimizer(cfg, params_groups): - return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) - - -def build_schedulers(cfg): - OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH - lr = dict( - base_value=cfg.optim["lr"], - final_value=cfg.optim["min_lr"], - total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, - warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, - start_warmup_value=0, - ) - wd = dict( - base_value=cfg.optim["weight_decay"], - final_value=cfg.optim["weight_decay_end"], - total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, - ) - momentum = dict( - base_value=cfg.teacher["momentum_teacher"], - final_value=cfg.teacher["final_momentum_teacher"], - total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, - ) - teacher_temp = dict( - base_value=cfg.teacher["teacher_temp"], - final_value=cfg.teacher["teacher_temp"], - total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, - warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, - start_warmup_value=cfg.teacher["warmup_teacher_temp"], - ) - - lr_schedule = CosineScheduler(**lr) - wd_schedule = CosineScheduler(**wd) - momentum_schedule = CosineScheduler(**momentum) - teacher_temp_schedule = CosineScheduler(**teacher_temp) - last_layer_lr_schedule = CosineScheduler(**lr) - - last_layer_lr_schedule.schedule[ - : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH - ] = 0 # mimicking the original schedules - - logger.info("Schedulers ready.") - - return ( - lr_schedule, - wd_schedule, - momentum_schedule, - teacher_temp_schedule, - last_layer_lr_schedule, - ) - - -def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): - for param_group in optimizer.param_groups: - is_last_layer = param_group["is_last_layer"] - lr_multiplier = 
param_group["lr_multiplier"] - wd_multiplier = param_group["wd_multiplier"] - param_group["weight_decay"] = wd * wd_multiplier - param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier - - -def do_test(cfg, model, iteration): - new_state_dict = model.teacher.state_dict() - - if distributed.is_main_process(): - iterstring = str(iteration) - eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) - os.makedirs(eval_dir, exist_ok=True) - # save teacher checkpoint - teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") - torch.save({"teacher": new_state_dict}, teacher_ckp_path) - - -def do_train(cfg, model, resume=False): - model.train() - inputs_dtype = torch.half - fp16_scaler = model.fp16_scaler # for mixed precision training - - # setup optimizer - - optimizer = build_optimizer(cfg, model.get_params_groups()) - ( - lr_schedule, - wd_schedule, - momentum_schedule, - teacher_temp_schedule, - last_layer_lr_schedule, - ) = build_schedulers(cfg) - - # checkpointer - checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) - - start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 - - OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH - max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH - - periodic_checkpointer = PeriodicCheckpointer( - checkpointer, - period=3 * OFFICIAL_EPOCH_LENGTH, - max_iter=max_iter, - max_to_keep=3, - ) - - # setup data preprocessing - - img_size = cfg.crops.global_crops_size - patch_size = cfg.student.patch_size - n_tokens = (img_size // patch_size) ** 2 - mask_generator = MaskingGenerator( - input_size=(img_size // patch_size, img_size // patch_size), - max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, - ) - - data_transform = DataAugmentationDINO( - cfg.crops.global_crops_scale, - cfg.crops.local_crops_scale, - cfg.crops.local_crops_number, - global_crops_size=cfg.crops.global_crops_size, - local_crops_size=cfg.crops.local_crops_size, - ) - - collate_fn = partial( - collate_data_and_cast, - mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, - mask_probability=cfg.ibot.mask_sample_probability, - n_tokens=n_tokens, - mask_generator=mask_generator, - dtype=inputs_dtype, - ) - - # setup data loader - - dataset = make_dataset( - dataset_str=cfg.train.dataset_path, - transform=data_transform, - target_transform=lambda _: (), - ) - # sampler_type = SamplerType.INFINITE - sampler_type = SamplerType.SHARDED_INFINITE - data_loader = make_data_loader( - dataset=dataset, - batch_size=cfg.train.batch_size_per_gpu, - num_workers=cfg.train.num_workers, - shuffle=True, - seed=start_iter, # TODO: Fix this -- cfg.train.seed - sampler_type=sampler_type, - sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, - drop_last=True, - collate_fn=collate_fn, - ) - - # training loop - - iteration = start_iter - - logger.info("Starting training from iteration {}".format(start_iter)) - metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") - metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) - header = "Training" - - for data in metric_logger.log_every( - data_loader, - 10, - header, - max_iter, - start_iter, - ): - current_batch_size = data["collated_global_crops"].shape[0] / 2 - if iteration > max_iter: - return - - # apply schedules - - lr = lr_schedule[iteration] - wd = wd_schedule[iteration] - mom = momentum_schedule[iteration] - teacher_temp = 
teacher_temp_schedule[iteration] - last_layer_lr = last_layer_lr_schedule[iteration] - apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) - - # compute losses - - optimizer.zero_grad(set_to_none=True) - loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) - - # clip gradients - - if fp16_scaler is not None: - if cfg.optim.clip_grad: - fp16_scaler.unscale_(optimizer) - for v in model.student.values(): - v.clip_grad_norm_(cfg.optim.clip_grad) - fp16_scaler.step(optimizer) - fp16_scaler.update() - else: - if cfg.optim.clip_grad: - for v in model.student.values(): - v.clip_grad_norm_(cfg.optim.clip_grad) - optimizer.step() - - # perform teacher EMA update - - model.update_student_shadow(mom) - - # logging - - if distributed.get_global_size() > 1: - for v in loss_dict.values(): - torch.distributed.all_reduce(v) - loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} - - if math.isnan(sum(loss_dict_reduced.values())): - logger.info("NaN detected") - raise AssertionError - losses_reduced = sum(loss for loss in loss_dict_reduced.values()) - - metric_logger.update(lr=lr) - metric_logger.update(wd=wd) - metric_logger.update(mom=mom) - metric_logger.update(last_layer_lr=last_layer_lr) - metric_logger.update(current_batch_size=current_batch_size) - metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) - - # checkpointing and testing - - if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: - do_test(cfg, model, f"training_{iteration}") - torch.cuda.synchronize() - periodic_checkpointer.step(iteration) - - iteration = iteration + 1 - metric_logger.synchronize_between_processes() - return {k: meter.global_avg for k, meter in metric_logger.meters.items()} - - -def main(args): - cfg = setup(args) - - model = DistillMetaArch(cfg).to(torch.device("cuda")) - model.prepare_for_distributed_training() - - logger.info("Model:\n{}".format(model)) - if args.eval_only: - iteration = ( - FSDPCheckpointer(model, save_dir=cfg.train.output_dir) - .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) - .get("iteration", -1) - + 1 - ) - return do_test(cfg, model, f"manual_{iteration}") - - do_train(cfg, model, resume=not args.no_resume) - - -if __name__ == "__main__": - args = get_args_parser(add_help=True).parse_args() - main(args) \ No newline at end of file diff --git a/dinov2/utils/cluster.py b/dinov2/utils/cluster.py index d5c9143..3df87dc 100644 --- a/dinov2/utils/cluster.py +++ b/dinov2/utils/cluster.py @@ -76,20 +76,20 @@ def get_slurm_executor_parameters( ) -> Dict[str, Any]: # create default parameters params = { - # "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html "gpus_per_node": num_gpus_per_node, "tasks_per_node": num_gpus_per_node, # one task per GPU - "cpus_per_gpu": 7, + "cpus_per_task": 10, "nodes": nodes, "slurm_partition": get_slurm_partition(cluster_type), } # apply cluster-specific adjustments cluster_type = get_cluster_type(cluster_type) if cluster_type == ClusterType.AWS: - params["cpus_per_gpu"] = 12 + params["cpus_per_task"] = 12 del params["mem_gb"] elif cluster_type == ClusterType.RSC: - params["cpus_per_gpu"] = 12 + params["cpus_per_task"] = 12 # set additional parameters / apply overrides params.update(kwargs) return params diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py index c9de578..06b16c0 100644 --- 
a/dinov2/utils/config.py
+++ b/dinov2/utils/config.py
@@ -7,6 +7,7 @@
 import math
 import logging
 import os
+import torch
 
 from omegaconf import OmegaConf
 import dinov2.distributed as distributed
@@ -46,6 +47,44 @@ def get_cfg_from_args(args):
     return cfg
 
 
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+def init_distributed_mode(args):
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
 def default_setup(args):
     distributed.enable(overwrite=True)
     seed = getattr(args, "seed", 0)
@@ -66,7 +105,8 @@ def setup(args):
     """
     cfg = get_cfg_from_args(args)
     os.makedirs(args.output_dir, exist_ok=True)
-    default_setup(args)
+    # default_setup(args)
+    init_distributed_mode(args)
     apply_scaling_rules_to_cfg(cfg)
     write_config(cfg, args.output_dir)
     return cfg
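# Usage sketch (not part of the diff; the config path is just an example): the
# init_distributed_mode() added above expects the torchelastic variables
# RANK / WORLD_SIZE / LOCAL_RANK, so a single-node run would typically be launched with
#
#   torchrun --nproc_per_node=8 dinov2/run/train/train.py \
#       --config-file dinov2/configs/train/nextvit.yaml --output-dir ./out
#
# Its branch order, distilled: torchelastic env vars first, then SLURM, otherwise fall back
# to a non-distributed single process.
import os
import torch

def detect_launcher():
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:    # torchrun / torchelastic
        return "elastic", int(os.environ["LOCAL_RANK"])
    if "SLURM_PROCID" in os.environ:                           # srun / sbatch
        return "slurm", int(os.environ["SLURM_PROCID"]) % max(torch.cuda.device_count(), 1)
    return "single-process", 0                                 # sets args.distributed = False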