diff --git a/dinov2/configs/train/custom.yaml b/dinov2/configs/train/custom.yaml index 44457db..39f31e4 100644 --- a/dinov2/configs/train/custom.yaml +++ b/dinov2/configs/train/custom.yaml @@ -5,38 +5,76 @@ compute_precision: backbone: sharding_strategy: SHARD_GRAD_OP mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 + param_dtype: fp32 + reduce_dtype: fp32 buffer_dtype: fp32 dino_head: sharding_strategy: SHARD_GRAD_OP mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - ibot_head: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - student: - backbone: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 - reduce_dtype: fp16 - buffer_dtype: fp32 - dino_head: - sharding_strategy: SHARD_GRAD_OP - mixed_precision: - param_dtype: fp16 + param_dtype: fp32 reduce_dtype: fp32 buffer_dtype: fp32 ibot_head: sharding_strategy: SHARD_GRAD_OP mixed_precision: - param_dtype: fp16 + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + merge_blocks: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + pre_encoder: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + cast_forward_inputs: False + model_adapter: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + student: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + merge_blocks: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + pre_encoder: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 + reduce_dtype: fp32 + buffer_dtype: fp32 + cast_forward_inputs: False + model_adapter: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp32 reduce_dtype: fp32 buffer_dtype: fp32 @@ -48,8 +86,8 @@ data_transform: "default" train: batch_size_per_gpu: 16 #vitg 26+, vitl: 56, vits:152, vitb:120 for 8 node num_workers: 1 - OFFICIAL_EPOCH_LENGTH: 100 # 1250 - dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/ + OFFICIAL_EPOCH_LENGTH: 1000 # 1250 + dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/ #MIT:root=/home/paperspace/Documents/nika_space/mit_dataset/train/ centering: sinkhorn_knopp drop_path_rate: 0.4 @@ -61,8 +99,8 @@ train: teacher: momentum_teacher: 0.994 optim: - epochs: 20 # 500 - weight_decay_end: 0.2 + epochs: 50 # 500 + weight_decay_end: 0.3 base_lr: 0.0001 # learning rate for a batch size of 1024 warmup_epochs: 20 # 80 layerwise_decay: 1.0 @@ -76,7 +114,7 @@ evaluation: # "dinov2_vits14","dinov2_vitb14","dinov2_vitl14","dinov2_vitg14" student: - arch: vit_base + arch: dinov2_vitb14 patch_size: 14 merge_block_indexes: "" # num, num, num, crops: diff --git a/dinov2/fsdp/__init__.py b/dinov2/fsdp/__init__.py index c4b39a6..fda65ee 100644 --- a/dinov2/fsdp/__init__.py +++ b/dinov2/fsdp/__init__.py @@ -17,6 +17,7 @@ from torch.distributed.fsdp import StateDictType from 
torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.fsdp._runtime_utils import _reshard +import torch.nn as nn def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): @@ -32,16 +33,22 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): "bf16": torch.bfloat16, } + param_dtype = dtype_dict[model_cfg.mixed_precision.param_dtype] + reduce_dtype = dtype_dict[model_cfg.mixed_precision.reduce_dtype] + buffer_dtype = dtype_dict[model_cfg.mixed_precision.buffer_dtype] + mixed_precision_config = MixedPrecision( - param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], - reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], - buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + cast_forward_inputs=model_cfg.mixed_precision.get("cast_forward_inputs", True) ) sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] local_rank = distributed.get_local_rank() + print("Modules to wrap: ", modules_to_wrap) fsdp_wrapper = partial( FSDP, sharding_strategy=sharding_strategy_config, @@ -51,6 +58,7 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): use_orig_params=True, auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), ) + return fsdp_wrapper @@ -112,7 +120,7 @@ class FSDPCheckpointer(Checkpointer): self.tag_last_checkpoint(basename) def load(self, *args, **kwargs): - with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT): return super().load(*args, **kwargs) def has_checkpoint(self) -> bool: diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py index 53fe837..ec054f4 100644 --- a/dinov2/hub/backbones.py +++ b/dinov2/hub/backbones.py @@ -28,6 +28,7 @@ def _make_dinov2_model( interpolate_offset: float = 0.1, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, + merge_block_indexes: list[int] = [0], **kwargs, ): from ..models import vision_transformer as vits @@ -48,8 +49,10 @@ def _make_dinov2_model( num_register_tokens=num_register_tokens, interpolate_antialias=interpolate_antialias, interpolate_offset=interpolate_offset, + merge_block_indexes=merge_block_indexes ) vit_kwargs.update(**kwargs) + print("Kwargs: ", vit_kwargs) model = vits.__dict__[arch_name](**vit_kwargs) if pretrained: diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py index e910203..4009d2a 100644 --- a/dinov2/models/__init__.py +++ b/dinov2/models/__init__.py @@ -26,8 +26,10 @@ def build_model(args, only_teacher=False, img_size=224): num_register_tokens=args.num_register_tokens, interpolate_offset=args.interpolate_offset, interpolate_antialias=args.interpolate_antialias, - # merge_blocks_indexes=args.merge_block_indexes, + merge_block_indexes=args.student.merge_block_indexes, ) + print( vits.__dict__.keys()) + print("Merge block ind: ", args.student.merge_block_indexes) teacher = vits.__dict__[args.arch](**vit_kwargs) if only_teacher: return teacher, teacher.embed_dim @@ -41,4 +43,5 @@ def build_model(args, only_teacher=False, img_size=224): def build_model_from_cfg(cfg, only_teacher=False): + print("Only teacher", only_teacher) return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/dinov2/models/help.py b/dinov2/models/help.py index 8285ca8..e15750a 100644 --- a/dinov2/models/help.py +++ b/dinov2/models/help.py 
@@ -25,28 +25,30 @@ class Merge_block(BaseModule): self.ada_c = ada_c # 784 - embedded dim + adapter_c self.embeded_dim = 768 - self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c).to(torch.float16) - self.fc_2 = nn.Linear(mid_c, self.embeded_dim).to(torch.float16) + self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c) + print("Fc 1 type: ", self.fc_1.weight.dtype, self.fc_1.bias.dtype) + self.fc_2 = nn.Linear(mid_c, self.embeded_dim) self.return_ada = return_ada if self.return_ada: - self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1).to(torch.float16) # 1D Conv instead of 3x3 + self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1) # 1D Conv instead of 3x3 else: self.conv_3 = None def forward(self, fea, adapter, ratio=1.0): res = fea # print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c) - # print("before concatenation: ", fea.shape, adapter.shape) + # print("before concatenation: ", fea.dtype, adapter.dtype) fea = torch.cat([fea, adapter], dim=-1) # (B, seq_len, fea_c + ada_c) # print("after concatenation: ", fea.shape, adapter.shape) B, seq_len, C = fea.shape fea = fea.view(B * seq_len, C) + # print("before concatenation: ", fea.dtype, adapter.dtype) + fea = fea.to(self.fc_1.weight.dtype) fea = self.fc_1(fea) fea = fea.view(B, seq_len, -1) ada = self.fc_2(fea) fea_out = ratio * ada + res - if self.return_ada: ada = self.conv_3(fea.permute(0, 2, 1)) @@ -313,7 +315,7 @@ class CustomLayerNorm(nn.Module): # Predictor P_K class Kernel_Predictor(nn.Module): - def __init__(self, dim, mode='low', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, mode='normal', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads @@ -546,7 +548,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te return img -def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=3): # [9, 9] in LOD dataset, [3, 3] in other dataset +def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=1): # [9, 9] in LOD dataset, [3, 3] in other dataset out = [] for i in range(I1.shape[0]): I1_gain = gain[i] * I1[i,:,:,:] @@ -583,11 +585,11 @@ def WB_CCM(I2, ccm_matrix, distance): out_I4 = [] for i in range(I2.shape[0]): # SOG White Balance Algorithm - I3 = SoG_algo(I2[i,:,:,:], distance[i]) + I3 = SoG_algo(I2[i,:,:,:]) # Camera Color Matrix I4 = torch.tensordot(I3, ccm_matrix[i,:,:], dims=[[-1], [-1]]) - I4 = torch.clamp(I4, 1e-5, 1.0) + I4 = torch.clamp(I4, 1e-7, 1.0) out_I3.append(I3) out_I4.append(I4) @@ -620,7 +622,7 @@ class VitInputLevelAdapter(nn.Module): # (1). I1 --> I2: Denoise & Enhancement & Sharpen r1, r2, gain, sigma = self.Predictor_K(I1) I2 = Gain_Denoise(I1, r1, r2, gain, sigma, k_size=self.k_size) # (B,C,H,W) - I2 = torch.clamp(I2, 1e-5, 1.0) # normal & over-exposure + I2 = torch.clamp(I2, 1e-7, 1.0) # normal & over-exposure ccm_matrix, distance = self.Predictor_M(I2) # (2). 
I2 --> I3: White Balance, Shade of Gray diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index cacff27..fc655be 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -19,7 +19,7 @@ from torch.nn.init import trunc_normal_ import torch.nn.functional as F from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block from dinov2.models.help import Merge_block, Model_level_Adapeter -from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter +from dinov2.models.input_level_adapter import Input_level_Adapeter logger = logging.getLogger("dinov2") @@ -68,17 +68,18 @@ class DinoVisionTransformer(nn.Module): interpolate_antialias=False, interpolate_offset=0.1, # RAW adapter parameters - w_lut=False, + w_lut=True, light_mode='normal', lut_dim=32, k_size=3, merge_ratio=1.0, model_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_model_adapter_weights.pth', - input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_pre_encoder_weights.pth', + input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/mmsegmentation_github/extracted_pre_encoder_weights.pth', fea_c_s = [384, 768, 1920], ada_c_s = [16, 32, 64], mid_c_s = [384, 576, 768], - # merge_blocks_indexes=[], + merge_block_indexes=[], + is_teacher=False, ): """ Args: @@ -190,28 +191,31 @@ class DinoVisionTransformer(nn.Module): self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) # Initialize RAW adapter + # self.merge_block_indexes = merge_block_indexes + self.merge_block_indexes = [0, 6, 10] + print("Before if:") if self.w_lut: self.pre_encoder = Input_level_Adapeter(mode=light_mode, lut_dim=lut_dim, k_size=k_size, w_lut=w_lut) for param in self.pre_encoder.parameters(): param.requires_grad_(True) - # self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut) + self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut) self.model_adapter = Model_level_Adapeter(in_c=3, in_dim=ada_c_s[0], w_lut=self.w_lut) - # if model_adapter_path is not None: - # print("Loading model adapter:", model_adapter_path) - # adapter_state = torch.load(model_adapter_path, map_location="cpu") - # self.model_adapter.load_state_dict(adapter_state, strict=False) - # if input_level_adapter_path is not None: - # print("Loading input-level adapter:", input_level_adapter_path) - # adapter_state = torch.load(input_level_adapter_path, map_location="cpu") - # self.pre_encoder.load_state_dict(adapter_state) + if model_adapter_path is not None: + print("Loading model adapter:", model_adapter_path) + adapter_state = torch.load(model_adapter_path, map_location="cpu") + self.model_adapter.load_state_dict(adapter_state, strict=False) + if input_level_adapter_path is not None: + print("Loading input-level adapter:", input_level_adapter_path) + adapter_state = torch.load(input_level_adapter_path, map_location="cpu") + self.pre_encoder.load_state_dict(adapter_state) self.merge_blocks = [] - self.merge_blocks_indexes = merge_blocks_indexes + # Loop through the merge_blocks_indexes and create Merge_block instances - for i, idx in enumerate(self.merge_blocks_indexes): - return_ada = False if i == len(self.merge_blocks_indexes) - 1 else True # Only the last block gets return_ada=False - if i != 0 or i != len(self.merge_blocks_indexes) - 1: + for i, idx in enumerate(self.merge_block_indexes): + return_ada = False if i == len(self.merge_block_indexes) - 
1 else True # Only the last block gets return_ada=False + if i != 0 or i != len(self.merge_block_indexes) - 1: k = 1 else: k = i @@ -221,12 +225,57 @@ class DinoVisionTransformer(nn.Module): mid_c=mid_c_s[k], return_ada=return_ada ).to("cuda") - self.merge_blocks.append(merge_block) - # self.merge_blocks.to("cuda") - print(self.merge_blocks) + # merge_block_state = torch.load("/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/mmsegmentation_github/extracted_merged_blocks_weights_2.pth", map_location="cpu") + # merge_block.load_state_dict(merge_block_state) + # if is_teacher: + # for param in merge_block.parameters(): + # param.requires_grad = False + + self.merge_blocks.append(merge_block) + # self.merge_blocks.to("cuda") + print("MERGED BLOCKS:", self.merge_block_indexes) + + + # # Freeze the patch embedding + # for param in self.patch_embed.parameters(): + # param.requires_grad_(False) + + # # Freeze the position embedding + # if hasattr(self, 'pos_embed') and self.pos_embed is not None: + # self.pos_embed.requires_grad_(False) + + # # Freeze the cls token if it exists + # if hasattr(self, 'cls_token') and self.cls_token is not None: + # self.cls_token.requires_grad_(False) + + # # Freeze the transformer blocks + # for block in self.blocks: + # for param in block.parameters(): + # param.requires_grad_(False) + + # # Freeze the norm layer if it exists + # if hasattr(self, 'norm') and self.norm is not None: + # for param in self.norm.parameters(): + # param.requires_grad_(False) + + # # Freeze any head/projection layers + # if hasattr(self, 'head') and self.head is not None: + # for param in self.head.parameters(): + # param.requires_grad_(False) self.init_weights() + # self.freeze_dino_weights() + + + def freeze_dino_weights(self): + """Freeze all original DINO weights and keep only the adapter trainable.""" + for name, param in self.named_parameters(): + # Unfreeze only pre_encoder, model_adapter, and merge_blocks + if not any(sub in name for sub in ["pre_encoder", "model_adapter", "merge_blocks"]): + param.requires_grad = False + else: + param.requires_grad = True # Ensure adapters are trainable def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) @@ -284,33 +333,33 @@ class DinoVisionTransformer(nn.Module): def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape - # print("BLOCKS NUM: " , len(self.blocks), len(self.merge_blocks)) - + if self.w_lut: # I1, I2, I3, I4 x_raw = self.pre_encoder(x) + if self.w_lut: # I1, I2, I3, I4 ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]]) - # else: # I1, I2, I3 - # ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]]) + else: # I1, I2, I3 + ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]]) - # x = x_raw[-1] + x = x_raw[-1] # print("X before patch embedding ",ada.shape, x.shape ) x = self.patch_embed(x) - # if x.shape[1] == 256: - # ada = F.interpolate(ada, size=(64, 64), mode='bilinear', align_corners=False) - # elif x.shape[1] == 49: - # ada = F.interpolate(ada, size=(28, 28), mode='bilinear', align_corners=False) + if x.shape[1] == 256: + ada = F.interpolate(ada, size=(64, 64), mode='bilinear', align_corners=False) + elif x.shape[1] == 49: + ada = F.interpolate(ada, size=(28, 28), mode='bilinear', align_corners=False) # print("ada.shape ", ada.shape, x.shape) - # ada = self.patch_embed_for_model_adapter(ada) + ada = self.patch_embed_for_model_adapter(ada) # tensor2_reshaped = ada.transpose(1, 2) # [32, 768, 196] # print("ada.shape after embedding ", ada.shape, x.shape) if masks is not 
None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - # ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada) + ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - # ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1) - # ada = ada + self.interpolate_pos_encoding(ada, w, h) + ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1) + ada = ada + self.interpolate_pos_encoding(ada, w, h) x = x + self.interpolate_pos_encoding(x, w, h) @@ -328,8 +377,8 @@ class DinoVisionTransformer(nn.Module): # print("x.shape",x.shape) - # return x, ada - return x, None + return x, ada + # return x, None def forward_features_list(self, x_list, masks_list): @@ -349,7 +398,8 @@ class DinoVisionTransformer(nn.Module): x = blk(x) - if self.w_lut and ada is not None and i in self.merge_blocks_indexes: + if self.w_lut and ada is not None and i in self.merge_block_indexes: + # print("ERROR!") x_ada_pairs = [self.merge_blocks[indx](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)] x, ada_list = map(list, zip(*x_ada_pairs)) indx += 1 @@ -376,9 +426,11 @@ class DinoVisionTransformer(nn.Module): x, ada = self.prepare_tokens_with_masks(x, masks) indx = 0 for i, blk in enumerate(self.blocks): + # print("X type: ", x.dtype) x = blk(x) - if self.w_lut and ada is not None and i in self.merge_blocks_indexes: + if self.w_lut and ada is not None and i in self.merge_block_indexes: # print("HERE 11", x.shape, ada.shape) + # print("ERROR!") x, ada = self.merge_blocks[indx](x, ada, ratio=self.merge_ratio) indx += 1 @@ -392,7 +444,7 @@ class DinoVisionTransformer(nn.Module): } def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) + x, ada = self.prepare_tokens_with_masks(x) # If n is an int, take the n last blocks. If it's a list, take them output, total_block_len = [], len(self.blocks) blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n @@ -404,7 +456,7 @@ class DinoVisionTransformer(nn.Module): return output def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) + x, ada = self.prepare_tokens_with_masks(x) output, i, total_block_len = [], 0, len(self.blocks[-1]) # If n is an int, take the n last blocks. 
If it's a list, take them blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py index 4dabf09..219c55c 100644 --- a/dinov2/train/ssl_meta_arch.py +++ b/dinov2/train/ssl_meta_arch.py @@ -26,6 +26,39 @@ except ImportError: logger = logging.getLogger("dinov2") +import math +def interpolate_pos_encoding(x, w, h): + N = x.shape[1] - 1 + dim = x.shape[-1] + w0 = w / int(math.sqrt(N)) + h0 = h / int(math.sqrt(N)) + + # Interpolate the position embeddings without changing the first row (class token) + patch_pos_embed = nn.functional.interpolate( + x[:, 1:].reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0, h0), + mode="bicubic", + ) + + # assert int(w0) == patch_pos_embed.shape[-2] + # assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + # Concatenate the class token with the interpolated position embeddings + return torch.cat((x[:, :1], patch_pos_embed), dim=1) + + +def get_downloaded_dino_vit_interpolated(modelname="dinov2_vits14", merge_block_indexes=[], is_teacher=False): + print("HERE !") + model = torch.hub.load("facebookresearch/dinov2", modelname, pretrained=False, merge_block_indexes=merge_block_indexes, is_teacher=is_teacher) # + print("HERE 2") + input_tensor = model.pos_embed + input_tensor = input_tensor.to('cuda') + tensor_corr_shape = interpolate_pos_encoding(input_tensor, 16, 16) + pos_embed = nn.Parameter(torch.zeros(1, 257)) + pos_embed.data = tensor_corr_shape + model.pos_embed = pos_embed + return model class SSLMetaArch(nn.Module): @@ -37,7 +70,14 @@ class SSLMetaArch(nn.Module): student_model_dict = dict() teacher_model_dict = dict() - student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + if cfg.student.arch in ["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"]: + print("Load pre-trained encoder:") + student_backbone = get_downloaded_dino_vit_interpolated(cfg.student.arch, cfg.student.merge_block_indexes) + teacher_backbone = get_downloaded_dino_vit_interpolated(cfg.student.arch, cfg.student.merge_block_indexes, is_teacher=True) + embed_dict = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024, "dinov2_vitg14": 1536} + embed_dim = embed_dict[cfg.student.arch] + else: + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) student_model_dict["backbone"] = student_backbone teacher_model_dict["backbone"] = teacher_backbone logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") @@ -47,8 +87,9 @@ class SSLMetaArch(nn.Module): logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") student_backbone.load_state_dict(chkpt["model"], strict=False) - # if cfg.student.merge_block_indexes: - # merge_blocks_ind = cfg.student.merge_block_indexes + print("Before IF") + if cfg.student.merge_block_indexes: + self.merge_blocks_ind = cfg.student.merge_block_indexes self.embed_dim = embed_dim self.dino_out_dim = cfg.dino.head_n_prototypes @@ -115,6 +156,7 @@ class SSLMetaArch(nn.Module): self.need_to_synchronize_fsdp_streams = True + print("Student model dict: ", student_model_dict) self.student = nn.ModuleDict(student_model_dict) self.teacher = nn.ModuleDict(teacher_model_dict) @@ -392,12 +434,216 @@ class SSLMetaArch(nn.Module): def prepare_for_distributed_training(self): logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") - if 
has_batchnorms(self.student): - raise NotImplementedError + # if has_batchnorms(self.student): + # raise NotImplementedError # below will synchronize all student subnetworks across gpus: + # print("Self.student: ", self.student.keys()) + # for name, param in self.named_parameters(): + # # Unfreeze only pre_encoder, model_adapter, and merge_blocks + # if not any(sub in name for sub in ["pre_encoder", "model_adapter", "merge_blocks"]): + # param.requires_grad = False + # else: + # param.requires_grad = True + + print("Student keys: ", self.student.items()) for k, v in self.student.items(): self.teacher[k].load_state_dict(self.student[k].state_dict()) student_model_cfg = self.cfg.compute_precision.student[k] + print("Cfg: ", student_model_cfg) + self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) teacher_model_cfg = self.cfg.compute_precision.teacher[k] self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) + + print("Type: ", type(self.student["backbone"].pre_encoder)) + # if self.student["backbone"].pre_encoder: + # cfg = self.cfg.compute_precision.student.pre_encoder + # self.student["backbone"].pre_encoder = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder) + + cfg = self.cfg.compute_precision.student.pre_encoder + # if not isinstance(self.student["backbone"].pre_encoder.Predictor_K, FSDP): + # self.student["backbone"].pre_encoder.Predictor_K = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Linear, nn.Conv2d, nn.LayerNorm})(self.student["backbone"].pre_encoder.Predictor_K ) + import dinov2.distributed as distributed + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + self.student["backbone"].pre_encoder.Predictor_K = FSDP( + self.student["backbone"].pre_encoder.Predictor_K, + sharding_strategy=cfg.sharding_strategy, + mixed_precision=cfg.mixed_precision, + device_id=distributed.get_local_rank(), + sync_module_states=True, + use_orig_params=True, + ) + self.teacher["backbone"].pre_encoder.Predictor_K = FSDP( + self.teacher["backbone"].pre_encoder.Predictor_K, + sharding_strategy=cfg.sharding_strategy, + mixed_precision=cfg.mixed_precision, + device_id=distributed.get_local_rank(), + sync_module_states=True, + use_orig_params=True, + ) + # if not isinstance(self.student["backbone"].pre_encoder.Predictor_M, FSDP): + # self.student["backbone"].pre_encoder.Predictor_M = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder.Predictor_M ) + self.student["backbone"].pre_encoder.Predictor_M = FSDP( + self.student["backbone"].pre_encoder.Predictor_M, + sharding_strategy=cfg.sharding_strategy, + mixed_precision=cfg.mixed_precision, + device_id=distributed.get_local_rank(), + sync_module_states=True, + use_orig_params=True, + ) + self.teacher["backbone"].pre_encoder.Predictor_M = FSDP( + self.teacher["backbone"].pre_encoder.Predictor_M, + sharding_strategy=cfg.sharding_strategy, + mixed_precision=cfg.mixed_precision, + device_id=distributed.get_local_rank(), + sync_module_states=True, + use_orig_params=True, + ) + # if not isinstance(self.student["backbone"].pre_encoder.LUT, FSDP): + self.student["backbone"].pre_encoder.LUT = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder.LUT) + self.teacher["backbone"].pre_encoder.LUT = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.teacher["backbone"].pre_encoder.LUT) + cfg = self.cfg.compute_precision.student.model_adapter + 
self.student["backbone"].model_adapter = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].model_adapter) + self.teacher["backbone"].model_adapter = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.teacher["backbone"].model_adapter) + cfg = self.cfg.compute_precision.student.merge_blocks + for block in range(len(self.student["backbone"].merge_blocks)): + self.student["backbone"].merge_blocks[block] = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].merge_blocks[block]) + + + # def prepare_for_distributed_training(self): + # logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + + # # Synchronize all student subnetworks across GPUs + # for k, v in self.student.items(): + # self.teacher[k].load_state_dict(self.student[k].state_dict()) + # student_model_cfg = self.cfg.compute_precision.student[k] + # self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + # teacher_model_cfg = self.cfg.compute_precision.teacher[k] + # self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) + + # # Wrap the pre-encoder, model adapter, and merge blocks separately + # if hasattr(self, 'pre_encoder'): + # logger.info("Wrapping pre-encoder for FSDP") + # self.pre_encoder = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(self.pre_encoder) + + # if hasattr(self, 'model_adapter'): + # logger.info("Wrapping model adapter for FSDP") + # self.model_adapter = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(self.model_adapter) + + # if hasattr(self, 'merge_blocks'): + # logger.info("Wrapping merge blocks for FSDP") + # for i, block in enumerate(self.merge_blocks): + # self.merge_blocks[i] = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(block) + + # def prepare_for_distributed_training(self): + # """ + # Prepare both student and teacher models for distributed training using FSDP. + # This handles the issue of having both trainable and non-trainable parameters. 
+ # """ + # from torch.distributed.fsdp.wrap import ModuleWrapPolicy + # import copy + + # # Process student models + # for k in self.student.keys(): + # print("Preparing student model for distributed training:", k) + + # # First make all parameters trainable to satisfy FSDP requirements + # for param in self.student[k].parameters(): + # param.requires_grad_(True) + + # # Wrap with FSDP + # student_model_cfg = self.cfg.compute_precision.student[k] + # print("Student __: ", self.student[k], student_model_cfg) + # self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + + # # After FSDP wrapping, set appropriate parameters to not require gradient + # for name, param in self.student[k].named_parameters(): + # # Freeze DINO backbone parameters + # if any([x in name for x in [ + # 'pre_encoder', 'model_adapter', 'merge_blocks' + # ]]): + # # print("HERE HERE", name) + # param.requires_grad_(True) + # else: + # print("Name: ", name) + # param.requires_grad_(False) + # # print("Second option") + + # # Print statistics about trainable parameters + # trainable_params = 0 + # total_params = 0 + # for name, param in self.student[k].named_parameters(): + # total_params += param.numel() + # if param.requires_grad: + # trainable_params += param.numel() + + # print(f"Student {k}: Total params: {total_params:,}, Trainable params: {trainable_params:,} ({trainable_params/total_params:.2%})") + + # # Process teacher models + # for k in self.teacher.keys(): + # print("Preparing teacher model for distributed training:", k) + + # # Teacher model doesn't need gradient computation during training + # # But we first make all parameters trainable for FSDP wrapping + # for param in self.teacher[k].parameters(): + # param.requires_grad_(True) + + # # Wrap with FSDP + # teacher_model_cfg = self.cfg.compute_precision.teacher[k] + # print("Teacher __: ", self.teacher[k], teacher_model_cfg) + # self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) + + # # After FSDP wrapping, set all parameters to not require gradient + # # Teacher models are typically frozen and only updated via EMA + # for param in self.teacher[k].parameters(): + # param.requires_grad_(False) + + # # Print statistics + # total_params = sum(p.numel() for p in self.teacher[k].parameters()) + # print(f"Teacher {k}: Total params: {total_params:,}, All parameters frozen") + + # def prepare_for_distributed_training(self): + # logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + + # # Before FSDP wrapping, ensure uniform requires_grad within each BlockChunk + # for k, v in self.student.items(): + # for name, module in v.named_modules(): + # if isinstance(module, BlockChunk): + # # # Make requires_grad uniform within each BlockChunk + # # # Option 1: Make all parameters in the BlockChunk trainable if any are trainable + # # any_trainable = any(p.requires_grad for p in module.parameters()) + # # if any_trainable: + # # for param in module.parameters(): + # # param.requires_grad = True + + # # Option 2: Or alternatively, make all parameters in the BlockChunk non-trainable + # for param in module.parameters(): + # param.requires_grad = False + + # # Now proceed with FSDP wrapping + # for k, v in self.student.items(): + # self.teacher[k].load_state_dict(self.student[k].state_dict()) + # student_model_cfg = self.cfg.compute_precision.student[k] + # self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + 
# teacher_model_cfg = self.cfg.compute_precision.teacher[k] + # self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) + + + def freeze_original_dino_weights(self): + print("Freezing") + + # Freeze the original DINO weights + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze the pre-encoder, model adapter, and merge blocks + if hasattr(self, 'pre_encoder'): + for param in self.pre_encoder.parameters(): + param.requires_grad = True + if hasattr(self, 'model_adapter'): + for param in self.model_adapter.parameters(): + param.requires_grad = True + if hasattr(self, 'merge_blocks'): + for block in self.merge_blocks: + for param in block.parameters(): + param.requires_grad = True \ No newline at end of file diff --git a/dinov2/train/train.py b/dinov2/train/train.py index 8a63224..b1485b1 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -12,8 +12,7 @@ from functools import partial from fvcore.common.checkpoint import PeriodicCheckpointer import torch -from dinov2.data import SamplerType, make_data_loader, make_dataset -from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator + import dinov2.distributed as distributed from dinov2.fsdp import FSDPCheckpointer from dinov2.logging import MetricLogger @@ -118,6 +117,36 @@ def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): param_group["weight_decay"] = wd * wd_multiplier param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier +from torch.distributed.fsdp import StateDictType + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import FullStateDictConfig, StateDictType + +def do_test(cfg, model, iteration): + # Ensure we are on the main process before saving + if distributed.is_main_process(): + iterstring = str(iteration) + eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) + os.makedirs(eval_dir, exist_ok=True) + + # Define full state dict config + # full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + # # Set FSDP state_dict type before calling state_dict() + # with FSDP.state_dict_type(model.teacher, StateDictType.FULL_STATE_DICT, full_state_dict_config): + # new_state_dict = model.teacher.state_dict() + + from torch.distributed.fsdp import FullStateDictConfig, StateDictType, state_dict_type + + full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config): + checkpointer = FSDPCheckpointer(model, save_dir=cfg.train.output_dir) + checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + + # Save teacher checkpoint + # teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") + # torch.save({"teacher": new_state_dict}, teacher_ckp_path) + def do_test(cfg, model, iteration): new_state_dict = model.teacher.state_dict() @@ -133,15 +162,208 @@ def do_test(cfg, model, iteration): from torch.utils.tensorboard import SummaryWriter +# def do_train(cfg, model, resume=False): + +# from dinov2.data import SamplerType, make_data_loader, make_dataset +# from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +# writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/with_merged_blocks_small_lr") +# model.train() +# inputs_dtype = torch.half +# fp16_scaler = model.fp16_scaler # for mixed precision training + +# # 
setup optimizer +# print("Cfg optim: ", cfg.optim) +# optimizer = build_optimizer(cfg, model.get_params_groups()) +# # optimizer = build_optimizer(cfg, [p for p in model.parameters() if p.requires_grad]) +# # trainable_params = [p for p in model.parameters() if p.requires_grad] +# # optimizer = torch.optim.AdamW(trainable_params, lr=lr) + +# ( +# lr_schedule, +# wd_schedule, +# momentum_schedule, +# teacher_temp_schedule, +# last_layer_lr_schedule, +# ) = build_schedulers(cfg) + +# # checkpointer +# checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) + +# start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + +# OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH +# max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + +# periodic_checkpointer = PeriodicCheckpointer( +# checkpointer, +# period=3 * OFFICIAL_EPOCH_LENGTH, +# max_iter=max_iter, +# max_to_keep=3, +# ) + +# # setup data preprocessing + +# img_size = cfg.crops.global_crops_size +# patch_size = cfg.student.patch_size +# n_tokens = (img_size // patch_size) ** 2 +# mask_generator = MaskingGenerator( +# input_size=(img_size // patch_size, img_size // patch_size), +# max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, +# ) + +# data_transform = DataAugmentationDINO( +# cfg.crops.global_crops_scale, +# cfg.crops.local_crops_scale, +# cfg.crops.local_crops_number, +# global_crops_size=cfg.crops.global_crops_size, +# local_crops_size=cfg.crops.local_crops_size, +# ) + +# collate_fn = partial( +# collate_data_and_cast, +# mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, +# mask_probability=cfg.ibot.mask_sample_probability, +# n_tokens=n_tokens, +# mask_generator=mask_generator, +# dtype=inputs_dtype, +# ) + +# # setup data loader + +# dataset = make_dataset( +# dataset_str=cfg.train.dataset_path, +# transform=data_transform, +# target_transform=lambda _: (), +# ) +# # sampler_type = SamplerType.INFINITE +# sampler_type = SamplerType.SHARDED_INFINITE +# data_loader = make_data_loader( +# dataset=dataset, +# batch_size=cfg.train.batch_size_per_gpu, +# num_workers=cfg.train.num_workers, +# shuffle=True, +# seed=start_iter, # TODO: Fix this -- cfg.train.seed +# sampler_type=sampler_type, +# sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, +# drop_last=True, +# collate_fn=collate_fn, +# ) + +# # training loop + +# iteration = start_iter + +# logger.info("Starting training from iteration {}".format(start_iter)) +# metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") +# metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) +# header = "Training" + +# for data in metric_logger.log_every( +# data_loader, +# 10, +# header, +# max_iter, +# start_iter, +# ): +# current_batch_size = data["collated_global_crops"].shape[0] / 2 +# if iteration > max_iter: +# return + +# # apply schedules + +# lr = lr_schedule[iteration] +# wd = wd_schedule[iteration] +# mom = momentum_schedule[iteration] +# teacher_temp = teacher_temp_schedule[iteration] +# last_layer_lr = last_layer_lr_schedule[iteration] +# apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + +# # compute losses + +# optimizer.zero_grad(set_to_none=True) +# loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + +# # clip gradients + +# if fp16_scaler is not None: +# if cfg.optim.clip_grad: +# fp16_scaler.unscale_(optimizer) +# for v in model.student.values(): +# 
v.clip_grad_norm_(cfg.optim.clip_grad) +# fp16_scaler.step(optimizer) +# fp16_scaler.update() +# else: +# if cfg.optim.clip_grad: +# for v in model.student.values(): +# v.clip_grad_norm_(cfg.optim.clip_grad) +# optimizer.step() + +# # perform teacher EMA update + +# model.update_teacher(mom) + +# # logging + +# if distributed.get_global_size() > 1: +# for v in loss_dict.values(): +# torch.distributed.all_reduce(v) +# loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + +# if math.isnan(sum(loss_dict_reduced.values())): +# logger.info("NaN detected") +# raise AssertionError +# losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + +# metric_logger.update(lr=lr) +# metric_logger.update(wd=wd) +# metric_logger.update(mom=mom) +# metric_logger.update(last_layer_lr=last_layer_lr) +# metric_logger.update(current_batch_size=current_batch_size) +# metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) +# writer.add_scalar('Loss/Total_Loss', losses_reduced, iteration) +# writer.add_scalar('Learning_Rate', lr, iteration) +# writer.add_scalar('Weight_Decay', wd, iteration) +# writer.add_scalar('Momentum', mom, iteration) +# # checkpointing and testing + +# if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: +# do_test(cfg, model, f"training_{iteration}") +# torch.cuda.synchronize() +# periodic_checkpointer.step(iteration) + +# iteration = iteration + 1 +# metric_logger.synchronize_between_processes() +# writer.close() +# return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + def do_train(cfg, model, resume=False): - writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/standart_custom_conf") + from dinov2.data import SamplerType, make_data_loader, make_dataset + from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator + writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/with_merged_blocks_small_lr") model.train() inputs_dtype = torch.half fp16_scaler = model.fp16_scaler # for mixed precision training - # setup optimizer + # # Freeze the original DINO weights + # for param in model.parameters(): + # param.requires_grad = False + # # Unfreeze the pre-encoder, model adapter, and merge blocks + # for param in model.pre_encoder.parameters(): + # param.requires_grad = True + # for param in model.model_adapter.parameters(): + # param.requires_grad = True + # for block in model.merge_blocks: + # for param in block.parameters(): + # param.requires_grad = True + + # setup optimizer + print("Cfg optim: ", cfg.optim) optimizer = build_optimizer(cfg, model.get_params_groups()) + # optimizer = build_optimizer(cfg, [p for p in model.parameters() if p.requires_grad]) + # trainable_params = [p for p in model.parameters() if p.requires_grad] + # optimizer = torch.optim.AdamW(trainable_params, lr=lr) + ( lr_schedule, wd_schedule, @@ -221,6 +443,7 @@ def do_train(cfg, model, resume=False): metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) header = "Training" + warmup_iterations = cfg.optim.warmup_epochs * OFFICIAL_EPOCH_LENGTH for data in metric_logger.log_every( data_loader, @@ -233,6 +456,12 @@ def do_train(cfg, model, resume=False): if iteration > max_iter: return + # Unfreeze original DINO weights after warmup + if iteration == warmup_iterations: + 
for param in model.parameters(): + param.requires_grad = True + + # apply schedules lr = lr_schedule[iteration] @@ -263,7 +492,6 @@ def do_train(cfg, model, resume=False): optimizer.step() # perform teacher EMA update - model.update_teacher(mom) # logging @@ -315,16 +543,17 @@ def parse_merge_block_indexes(config_value: str) -> List[int]: return [] return list(map(int, re.split(r'\s*,\s*', config_value))) - def main(args): cfg = setup(args) + print("cfg.student.merge_block_indexes ", cfg.student.merge_block_indexes) + cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes) + print("cfg.student.merge_block_indexes:", cfg.student.merge_block_indexes) - # cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes) - # print("INDEXES", cfg.student.merge_block_indexes) model = SSLMetaArch(cfg).to(torch.device("cuda")) model.prepare_for_distributed_training() - - # logger.info("Model:\n{}".format(model)) + + # model.freeze_original_dino_weights() + if args.eval_only: iteration = ( FSDPCheckpointer(model, save_dir=cfg.train.output_dir) @@ -333,7 +562,7 @@ def main(args): + 1 ) return do_test(cfg, model, f"manual_{iteration}") - + do_train(cfg, model, resume=not args.no_resume) diff --git a/dinov2/utils/utils.py b/dinov2/utils/utils.py index 68f8e2c..b19407a 100644 --- a/dinov2/utils/utils.py +++ b/dinov2/utils/utils.py @@ -18,17 +18,26 @@ logger = logging.getLogger("dinov2") def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") else: state_dict = torch.load(pretrained_weights, map_location="cpu") - if checkpoint_key is not None and checkpoint_key in state_dict: + state_dict = state_dict['model'] + if checkpoint_key is not None: + print("INSIDE IF") logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") - state_dict = state_dict[checkpoint_key] + state_dict = {key: value for key, value in state_dict.items() if key.startswith(checkpoint_key) } + # state_dict = state_dict[checkpoint_key] # remove `module.` prefix state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} # remove `backbone.` prefix induced by multicrop wrapper - state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + + state_dict = {k.replace("teacher.backbone.", ""): v for k, v in state_dict.items()} + # print("Model Keys:") + # for key in state_dict.keys(): + # print(key) + # print("State dict: ", state_dict.keys()) msg = model.load_state_dict(state_dict, strict=False) logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
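A few sketches follow; they are illustrations of points in the patch above, not part of the patch itself.

In vision_transformer.py the patch keeps the Merge_block instances in a plain Python list (self.merge_blocks = []) and moves each one to "cuda" by hand, which is also why ssl_meta_arch.py later has to FSDP-wrap every block individually. The sketch below assumes the Merge_block constructor takes keyword arguments along the lines of fea_c, ada_c, mid_c and return_ada (only mid_c and return_ada are visible in the hunk), uses the fea_c_s / ada_c_s / mid_c_s lists from the diff, and assumes the index-selection condition was meant to read "first or last block uses k = i". Registering the blocks in an nn.ModuleList lets named_parameters(), state_dict() and .to(device) pick them up automatically.

import torch.nn as nn

# Illustrative only -- mirrors the constructor loop in DinoVisionTransformer.__init__,
# assuming Merge_block accepts (fea_c, ada_c, mid_c, return_ada) as keyword arguments.
def build_merge_blocks(merge_block_indexes, fea_c_s, ada_c_s, mid_c_s, merge_block_cls):
    blocks = nn.ModuleList()
    last = len(merge_block_indexes) - 1
    for i, _block_idx in enumerate(merge_block_indexes):
        # Assumed intent of the original "i != 0 or i != last" check.
        k = i if (i == 0 or i == last) else 1
        blocks.append(
            merge_block_cls(
                fea_c=fea_c_s[k],
                ada_c=ada_c_s[k],
                mid_c=mid_c_s[k],
                return_ada=(i != last),  # only the last block stops returning the adapter stream
            )
        )
    return blocks

# inside DinoVisionTransformer.__init__ this would replace the manual list:
# self.merge_blocks = build_merge_blocks(self.merge_block_indexes, fea_c_s, ada_c_s, mid_c_s, Merge_block)

With the blocks registered this way, the separate .to("cuda") calls and the per-block FSDP wrapping loop at the end of prepare_for_distributed_training could operate on self.merge_blocks as a single submodule.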
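prepare_for_distributed_training wraps pre_encoder.Predictor_K and pre_encoder.Predictor_M with FSDP(...) directly, passing cfg.sharding_strategy (the string "SHARD_GRAD_OP") and cfg.mixed_precision (the raw config node) to the constructor, while get_fsdp_wrapper in dinov2/fsdp/__init__.py already converts both into torch's ShardingStrategy and MixedPrecision objects. A minimal sketch of reusing that wrapper is shown below; it assumes the compute_precision.student.pre_encoder layout from the YAML, and the wrap policy is the one the patch has commented out.

from torch import nn

from dinov2.fsdp import get_fsdp_wrapper

# Sketch: reuse the existing wrapper so the YAML strings ("SHARD_GRAD_OP", "fp32", ...)
# are converted to ShardingStrategy / MixedPrecision in exactly one place.
def wrap_pre_encoder(pre_encoder, pre_encoder_cfg):
    wrapper = get_fsdp_wrapper(pre_encoder_cfg, modules_to_wrap={nn.Linear, nn.Conv2d, nn.LayerNorm})
    pre_encoder.Predictor_K = wrapper(pre_encoder.Predictor_K)
    pre_encoder.Predictor_M = wrapper(pre_encoder.Predictor_M)
    return pre_encoder

# e.g. in SSLMetaArch.prepare_for_distributed_training:
# pre_cfg = self.cfg.compute_precision.student.pre_encoder
# wrap_pre_encoder(self.student["backbone"].pre_encoder, pre_cfg)
# wrap_pre_encoder(self.teacher["backbone"].pre_encoder, self.cfg.compute_precision.teacher.pre_encoder)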
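train.py now defines do_test twice, so the second (pre-existing) definition shadows the new FSDP-aware one; the new version also references args outside its scope and runs the checkpoint logic only on the main process even though gathering a FULL_STATE_DICT is a collective operation. A possible consolidation, sketched under the assumption that model.teacher is the FSDP-wrapped ModuleDict built in SSLMetaArch: every rank enters state_dict(), and only rank 0 writes the file.

import os

import torch
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import dinov2.distributed as distributed

def do_test(cfg, model, iteration):
    # Gathering a full (unsharded) state dict is collective: all ranks must call state_dict().
    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model.teacher, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        new_state_dict = model.teacher.state_dict()

    if distributed.is_main_process():
        eval_dir = os.path.join(cfg.train.output_dir, "eval", str(iteration))
        os.makedirs(eval_dir, exist_ok=True)
        teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
        torch.save({"teacher": new_state_dict}, teacher_ckp_path)

With rank0_only=True the gathered tensors only materialize on rank 0, so the other ranks pay the communication cost but not the host-memory cost of the full teacher checkpoint.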