FSDP for RAW adapter

pull/511/head
Veronikkkka 2025-03-18 10:22:05 +00:00
parent 63dfb318ec
commit 753b990c64
9 changed files with 691 additions and 101 deletions

View File

@@ -5,38 +5,76 @@ compute_precision:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
student:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
reduce_dtype: fp16
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp16
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
merge_blocks:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
pre_encoder:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
cast_forward_inputs: False
model_adapter:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
student:
backbone:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
dino_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
ibot_head:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
merge_blocks:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
pre_encoder:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
cast_forward_inputs: False
model_adapter:
sharding_strategy: SHARD_GRAD_OP
mixed_precision:
param_dtype: fp32
reduce_dtype: fp32
buffer_dtype: fp32
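
The per-module entries above are consumed by get_fsdp_wrapper (changed later in this commit). A minimal sketch, assuming an OmegaConf-style mapping with the same fp16/fp32/bf16 strings, of how one mixed_precision block translates into torch's MixedPrecision policy:

import torch
from torch.distributed.fsdp import MixedPrecision

_DTYPES = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

def mixed_precision_from_cfg(mp_cfg: dict) -> MixedPrecision:
    # Translate one mixed_precision block of the YAML above into an FSDP policy.
    return MixedPrecision(
        param_dtype=_DTYPES[mp_cfg["param_dtype"]],
        reduce_dtype=_DTYPES[mp_cfg["reduce_dtype"]],
        buffer_dtype=_DTYPES[mp_cfg["buffer_dtype"]],
        cast_forward_inputs=mp_cfg.get("cast_forward_inputs", True),
    )

# e.g. the student pre_encoder entry, which keeps everything in fp32 and disables input casting:
policy = mixed_precision_from_cfg(
    {"param_dtype": "fp32", "reduce_dtype": "fp32", "buffer_dtype": "fp32",
     "cast_forward_inputs": False}
)
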
@@ -48,8 +86,8 @@ data_transform: "default"
train:
batch_size_per_gpu: 16 #vitg 26+, vitl: 56, vits:152, vitb:120 for 8 node
num_workers: 1
OFFICIAL_EPOCH_LENGTH: 100 # 1250
dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/
OFFICIAL_EPOCH_LENGTH: 1000 # 1250
dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/ #MIT:root=/home/paperspace/Documents/nika_space/mit_dataset/train/
centering: sinkhorn_knopp
drop_path_rate: 0.4
@@ -61,8 +99,8 @@ train:
teacher:
momentum_teacher: 0.994
optim:
epochs: 20 # 500
weight_decay_end: 0.2
epochs: 50 # 500
weight_decay_end: 0.3
base_lr: 0.0001 # learning rate for a batch size of 1024
warmup_epochs: 20 # 80
layerwise_decay: 1.0
@@ -76,7 +114,7 @@ evaluation:
# "dinov2_vits14","dinov2_vitb14","dinov2_vitl14","dinov2_vitg14"
student:
arch: vit_base
arch: dinov2_vitb14
patch_size: 14
merge_block_indexes: "" # num, num, num,
crops:

View File

@@ -17,6 +17,7 @@ from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp._runtime_utils import _reshard
import torch.nn as nn
def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
@@ -32,16 +33,22 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
"bf16": torch.bfloat16,
}
param_dtype = dtype_dict[model_cfg.mixed_precision.param_dtype]
reduce_dtype = dtype_dict[model_cfg.mixed_precision.reduce_dtype]
buffer_dtype = dtype_dict[model_cfg.mixed_precision.buffer_dtype]
mixed_precision_config = MixedPrecision(
param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
cast_forward_inputs=model_cfg.mixed_precision.get("cast_forward_inputs", True)
)
sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]
local_rank = distributed.get_local_rank()
print("Modules to wrap: ", modules_to_wrap)
fsdp_wrapper = partial(
FSDP,
sharding_strategy=sharding_strategy_config,
@@ -51,6 +58,7 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
use_orig_params=True,
auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
)
return fsdp_wrapper
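
Because the helper returns a functools.partial over FSDP, call sites can apply it like a constructor. A rough usage sketch of the same pattern in isolation (the toy module and the nn.Linear wrap policy are illustrative, not from this commit; it also assumes torch.distributed is already initialized, e.g. via torchrun):

import torch.nn as nn
from functools import partial
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Same pattern as get_fsdp_wrapper: pre-bind the FSDP arguments, then wrap modules later.
wrapper = partial(
    FSDP,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    use_orig_params=True,
    auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),  # each nn.Linear becomes its own FSDP unit
)

model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
model = wrapper(model)  # requires an initialized process group
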
@@ -112,7 +120,7 @@ class FSDPCheckpointer(Checkpointer):
self.tag_last_checkpoint(basename)
def load(self, *args, **kwargs):
with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
return super().load(*args, **kwargs)
def has_checkpoint(self) -> bool:
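
The load path now opens the checkpoint under FULL_STATE_DICT instead of LOCAL_STATE_DICT, i.e. it expects one consolidated, unsharded state dict rather than per-rank shards. A hedged sketch of the two modes, independent of this repo's checkpoint format:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def full_state_dict(fsdp_model):
    # One consolidated dict with unflattened keys, gathered to rank 0 and offloaded to CPU.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()

def local_state_dict(fsdp_model):
    # Per-rank shards of the flattened parameters; every rank saves/loads only its own piece.
    with FSDP.state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT):
        return fsdp_model.state_dict()
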

View File

@@ -28,6 +28,7 @@ def _make_dinov2_model(
interpolate_offset: float = 0.1,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.LVD142M,
merge_block_indexes: list[int] = [0],
**kwargs,
):
from ..models import vision_transformer as vits
@@ -48,8 +49,10 @@
num_register_tokens=num_register_tokens,
interpolate_antialias=interpolate_antialias,
interpolate_offset=interpolate_offset,
merge_block_indexes=merge_block_indexes
)
vit_kwargs.update(**kwargs)
print("Kwargs: ", vit_kwargs)
model = vits.__dict__[arch_name](**vit_kwargs)
if pretrained:

View File

@@ -26,8 +26,10 @@ def build_model(args, only_teacher=False, img_size=224):
num_register_tokens=args.num_register_tokens,
interpolate_offset=args.interpolate_offset,
interpolate_antialias=args.interpolate_antialias,
# merge_blocks_indexes=args.merge_block_indexes,
merge_block_indexes=args.student.merge_block_indexes,
)
print( vits.__dict__.keys())
print("Merge block ind: ", args.student.merge_block_indexes)
teacher = vits.__dict__[args.arch](**vit_kwargs)
if only_teacher:
return teacher, teacher.embed_dim
@@ -41,4 +43,5 @@ def build_model(args, only_teacher=False, img_size=224):
def build_model_from_cfg(cfg, only_teacher=False):
print("Only teacher", only_teacher)
return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)

View File

@@ -25,28 +25,30 @@ class Merge_block(BaseModule):
self.ada_c = ada_c
# 784 - embedded dim + adapter_c
self.embeded_dim = 768
self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c).to(torch.float16)
self.fc_2 = nn.Linear(mid_c, self.embeded_dim).to(torch.float16)
self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c)
print("Fc 1 type: ", self.fc_1.weight.dtype, self.fc_1.bias.dtype)
self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
self.return_ada = return_ada
if self.return_ada:
self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1).to(torch.float16) # 1D Conv instead of 3x3
self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1) # 1D Conv instead of 3x3
else:
self.conv_3 = None
def forward(self, fea, adapter, ratio=1.0):
res = fea
# print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)
# print("before concatenation: ", fea.shape, adapter.shape)
# print("before concatenation: ", fea.dtype, adapter.dtype)
fea = torch.cat([fea, adapter], dim=-1) # (B, seq_len, fea_c + ada_c)
# print("after concatenation: ", fea.shape, adapter.shape)
B, seq_len, C = fea.shape
fea = fea.view(B * seq_len, C)
# print("before concatenation: ", fea.dtype, adapter.dtype)
fea = fea.to(self.fc_1.weight.dtype)
fea = self.fc_1(fea)
fea = fea.view(B, seq_len, -1)
ada = self.fc_2(fea)
fea_out = ratio * ada + res
if self.return_ada:
ada = self.conv_3(fea.permute(0, 2, 1))
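
Dropping the hard-coded .to(torch.float16) on the layers and casting the activation to self.fc_1.weight.dtype instead lets the block follow whatever dtype the FSDP mixed-precision policy assigns to its parameters. A small sketch of that pattern on an illustrative module (not the repo's Merge_block):

import torch
import torch.nn as nn

class DtypeSafeProj(nn.Module):
    def __init__(self, in_c: int, out_c: int):
        super().__init__()
        self.fc = nn.Linear(in_c, out_c)  # dtype left to the surrounding precision policy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast activations to whatever dtype the (possibly FSDP-managed) weights use.
        return self.fc(x.to(self.fc.weight.dtype))

proj = DtypeSafeProj(16, 8)
out = proj(torch.randn(4, 16, dtype=torch.float16))  # no fp16/fp32 mismatch
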
@@ -313,7 +315,7 @@ class CustomLayerNorm(nn.Module):
# Predictor P_K
class Kernel_Predictor(nn.Module):
def __init__(self, dim, mode='low', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(self, dim, mode='normal', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@@ -546,7 +548,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
return img
def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=3): # [9, 9] in LOD dataset, [3, 3] in other dataset
def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=1): # [9, 9] in LOD dataset, [3, 3] in other dataset
out = []
for i in range(I1.shape[0]):
I1_gain = gain[i] * I1[i,:,:,:]
@@ -583,11 +585,11 @@ def WB_CCM(I2, ccm_matrix, distance):
out_I4 = []
for i in range(I2.shape[0]):
# SOG White Balance Algorithm
I3 = SoG_algo(I2[i,:,:,:], distance[i])
I3 = SoG_algo(I2[i,:,:,:])
# Camera Color Matrix
I4 = torch.tensordot(I3, ccm_matrix[i,:,:], dims=[[-1], [-1]])
I4 = torch.clamp(I4, 1e-5, 1.0)
I4 = torch.clamp(I4, 1e-7, 1.0)
out_I3.append(I3)
out_I4.append(I4)
@@ -620,7 +622,7 @@ class VitInputLevelAdapter(nn.Module):
# (1). I1 --> I2: Denoise & Enhancement & Sharpen
r1, r2, gain, sigma = self.Predictor_K(I1)
I2 = Gain_Denoise(I1, r1, r2, gain, sigma, k_size=self.k_size) # (B,C,H,W)
I2 = torch.clamp(I2, 1e-5, 1.0) # normal & over-exposure
I2 = torch.clamp(I2, 1e-7, 1.0) # normal & over-exposure
ccm_matrix, distance = self.Predictor_M(I2)
# (2). I2 --> I3: White Balance, Shade of Gray

View File

@@ -19,7 +19,7 @@ from torch.nn.init import trunc_normal_
import torch.nn.functional as F
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from dinov2.models.help import Merge_block, Model_level_Adapeter
from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter
from dinov2.models.input_level_adapter import Input_level_Adapeter
logger = logging.getLogger("dinov2")
@@ -68,17 +68,18 @@ class DinoVisionTransformer(nn.Module):
interpolate_antialias=False,
interpolate_offset=0.1,
# RAW adapter parameters
w_lut=False,
w_lut=True,
light_mode='normal',
lut_dim=32,
k_size=3,
merge_ratio=1.0,
model_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_model_adapter_weights.pth',
input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_pre_encoder_weights.pth',
input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/mmsegmentation_github/extracted_pre_encoder_weights.pth',
fea_c_s = [384, 768, 1920],
ada_c_s = [16, 32, 64],
mid_c_s = [384, 576, 768],
# merge_blocks_indexes=[],
merge_block_indexes=[],
is_teacher=False,
):
"""
Args:
@@ -190,28 +191,31 @@ class DinoVisionTransformer(nn.Module):
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
# Initialize RAW adapter
# self.merge_block_indexes = merge_block_indexes
self.merge_block_indexes = [0, 6, 10]
print("Before if:")
if self.w_lut:
self.pre_encoder = Input_level_Adapeter(mode=light_mode, lut_dim=lut_dim, k_size=k_size, w_lut=w_lut)
for param in self.pre_encoder.parameters():
param.requires_grad_(True)
# self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut)
self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut)
self.model_adapter = Model_level_Adapeter(in_c=3, in_dim=ada_c_s[0], w_lut=self.w_lut)
# if model_adapter_path is not None:
# print("Loading model adapter:", model_adapter_path)
# adapter_state = torch.load(model_adapter_path, map_location="cpu")
# self.model_adapter.load_state_dict(adapter_state, strict=False)
# if input_level_adapter_path is not None:
# print("Loading input-level adapter:", input_level_adapter_path)
# adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
# self.pre_encoder.load_state_dict(adapter_state)
if model_adapter_path is not None:
print("Loading model adapter:", model_adapter_path)
adapter_state = torch.load(model_adapter_path, map_location="cpu")
self.model_adapter.load_state_dict(adapter_state, strict=False)
if input_level_adapter_path is not None:
print("Loading input-level adapter:", input_level_adapter_path)
adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
self.pre_encoder.load_state_dict(adapter_state)
self.merge_blocks = []
self.merge_blocks_indexes = merge_blocks_indexes
# Loop through the merge_blocks_indexes and create Merge_block instances
for i, idx in enumerate(self.merge_blocks_indexes):
return_ada = False if i == len(self.merge_blocks_indexes) - 1 else True # Only the last block gets return_ada=False
if i != 0 or i != len(self.merge_blocks_indexes) - 1:
for i, idx in enumerate(self.merge_block_indexes):
return_ada = False if i == len(self.merge_block_indexes) - 1 else True # Only the last block gets return_ada=False
if i != 0 or i != len(self.merge_block_indexes) - 1:
k = 1
else:
k = i
@@ -221,12 +225,57 @@ class DinoVisionTransformer(nn.Module):
mid_c=mid_c_s[k],
return_ada=return_ada
).to("cuda")
self.merge_blocks.append(merge_block)
# self.merge_blocks.to("cuda")
print(self.merge_blocks)
# merge_block_state = torch.load("/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/mmsegmentation_github/extracted_merged_blocks_weights_2.pth", map_location="cpu")
# merge_block.load_state_dict(merge_block_state)
# if is_teacher:
# for param in merge_block.parameters():
# param.requires_grad = False
self.merge_blocks.append(merge_block)
# self.merge_blocks.to("cuda")
print("MERGED BLOCKS:", self.merge_block_indexes)
# # Freeze the patch embedding
# for param in self.patch_embed.parameters():
# param.requires_grad_(False)
# # Freeze the position embedding
# if hasattr(self, 'pos_embed') and self.pos_embed is not None:
# self.pos_embed.requires_grad_(False)
# # Freeze the cls token if it exists
# if hasattr(self, 'cls_token') and self.cls_token is not None:
# self.cls_token.requires_grad_(False)
# # Freeze the transformer blocks
# for block in self.blocks:
# for param in block.parameters():
# param.requires_grad_(False)
# # Freeze the norm layer if it exists
# if hasattr(self, 'norm') and self.norm is not None:
# for param in self.norm.parameters():
# param.requires_grad_(False)
# # Freeze any head/projection layers
# if hasattr(self, 'head') and self.head is not None:
# for param in self.head.parameters():
# param.requires_grad_(False)
self.init_weights()
# self.freeze_dino_weights()
def freeze_dino_weights(self):
"""Freeze all original DINO weights and keep only the adapter trainable."""
for name, param in self.named_parameters():
# Unfreeze only pre_encoder, model_adapter, and merge_blocks
if not any(sub in name for sub in ["pre_encoder", "model_adapter", "merge_blocks"]):
param.requires_grad = False
else:
param.requires_grad = True # Ensure adapters are trainable
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
@@ -284,33 +333,33 @@ class DinoVisionTransformer(nn.Module):
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
# print("BLOCKS NUM: " , len(self.blocks), len(self.merge_blocks))
if self.w_lut: # I1, I2, I3, I4
x_raw = self.pre_encoder(x)
if self.w_lut: # I1, I2, I3, I4
ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]])
# else: # I1, I2, I3
# ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]])
else: # I1, I2, I3
ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]])
# x = x_raw[-1]
x = x_raw[-1]
# print("X before patch embedding ",ada.shape, x.shape )
x = self.patch_embed(x)
# if x.shape[1] == 256:
# ada = F.interpolate(ada, size=(64, 64), mode='bilinear', align_corners=False)
# elif x.shape[1] == 49:
# ada = F.interpolate(ada, size=(28, 28), mode='bilinear', align_corners=False)
if x.shape[1] == 256:
ada = F.interpolate(ada, size=(64, 64), mode='bilinear', align_corners=False)
elif x.shape[1] == 49:
ada = F.interpolate(ada, size=(28, 28), mode='bilinear', align_corners=False)
# print("ada.shape ", ada.shape, x.shape)
# ada = self.patch_embed_for_model_adapter(ada)
ada = self.patch_embed_for_model_adapter(ada)
# tensor2_reshaped = ada.transpose(1, 2) # [32, 768, 196]
# print("ada.shape after embedding ", ada.shape, x.shape)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
# ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)
ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
# ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
# ada = ada + self.interpolate_pos_encoding(ada, w, h)
ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
ada = ada + self.interpolate_pos_encoding(ada, w, h)
x = x + self.interpolate_pos_encoding(x, w, h)
@@ -328,8 +377,8 @@ class DinoVisionTransformer(nn.Module):
# print("x.shape",x.shape)
# return x, ada
return x, None
return x, ada
# return x, None
def forward_features_list(self, x_list, masks_list):
@@ -349,7 +398,8 @@ class DinoVisionTransformer(nn.Module):
x = blk(x)
if self.w_lut and ada is not None and i in self.merge_blocks_indexes:
if self.w_lut and ada is not None and i in self.merge_block_indexes:
# print("ERROR!")
x_ada_pairs = [self.merge_blocks[indx](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
x, ada_list = map(list, zip(*x_ada_pairs))
indx += 1
@@ -376,9 +426,11 @@ class DinoVisionTransformer(nn.Module):
x, ada = self.prepare_tokens_with_masks(x, masks)
indx = 0
for i, blk in enumerate(self.blocks):
# print("X type: ", x.dtype)
x = blk(x)
if self.w_lut and ada is not None and i in self.merge_blocks_indexes:
if self.w_lut and ada is not None and i in self.merge_block_indexes:
# print("HERE 11", x.shape, ada.shape)
# print("ERROR!")
x, ada = self.merge_blocks[indx](x, ada, ratio=self.merge_ratio)
indx += 1
@@ -392,7 +444,7 @@ class DinoVisionTransformer(nn.Module):
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
x, ada = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
@@ -404,7 +456,7 @@ class DinoVisionTransformer(nn.Module):
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
x, ada = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n

View File

@@ -26,6 +26,39 @@ except ImportError:
logger = logging.getLogger("dinov2")
import math
def interpolate_pos_encoding(x, w, h):
N = x.shape[1] - 1
dim = x.shape[-1]
w0 = w / int(math.sqrt(N))
h0 = h / int(math.sqrt(N))
# Interpolate the position embeddings without changing the first row (class token)
patch_pos_embed = nn.functional.interpolate(
x[:, 1:].reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0, h0),
mode="bicubic",
)
# assert int(w0) == patch_pos_embed.shape[-2]
# assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
# Concatenate the class token with the interpolated position embeddings
return torch.cat((x[:, :1], patch_pos_embed), dim=1)
def get_downloaded_dino_vit_interpolated(modelname="dinov2_vits14", merge_block_indexes=[], is_teacher=False):
print("HERE !")
model = torch.hub.load("facebookresearch/dinov2", modelname, pretrained=False, merge_block_indexes=merge_block_indexes, is_teacher=is_teacher) #
print("HERE 2")
input_tensor = model.pos_embed
input_tensor = input_tensor.to('cuda')
tensor_corr_shape = interpolate_pos_encoding(input_tensor, 16, 16)
pos_embed = nn.Parameter(torch.zeros(1, 257))
pos_embed.data = tensor_corr_shape
model.pos_embed = pos_embed
return model
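
interpolate_pos_encoding above resizes the patch-token rows of a positional embedding while leaving the class-token row untouched; with (w, h) = (16, 16) and a 257-token DINOv2-style embedding the grid is unchanged, which is what the loader relies on. A quick shape check, assuming the helper defined in this hunk is in scope:

import math
import torch

pos_embed = torch.zeros(1, 257, 768)   # 1 cls token + 16x16 patch tokens, ViT-B width
N = pos_embed.shape[1] - 1             # 256 patch positions
grid = int(math.sqrt(N))               # 16
out = interpolate_pos_encoding(pos_embed, 16, 16)   # scale factor 16/16 = 1.0 per axis
assert out.shape == (1, 1 + grid * grid, 768)       # (1, 257, 768); larger w, h would grow the grid
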
class SSLMetaArch(nn.Module):
@@ -37,7 +70,14 @@ class SSLMetaArch(nn.Module):
student_model_dict = dict()
teacher_model_dict = dict()
student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
if cfg.student.arch in ["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"]:
print("Load pre-trained encoder:")
student_backbone = get_downloaded_dino_vit_interpolated(cfg.student.arch, cfg.student.merge_block_indexes)
teacher_backbone = get_downloaded_dino_vit_interpolated(cfg.student.arch, cfg.student.merge_block_indexes, is_teacher=True)
embed_dict = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024, "dinov2_vitg14": 1536}
embed_dim = embed_dict[cfg.student.arch]
else:
student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
student_model_dict["backbone"] = student_backbone
teacher_model_dict["backbone"] = teacher_backbone
logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")
@@ -47,8 +87,9 @@ class SSLMetaArch(nn.Module):
logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
student_backbone.load_state_dict(chkpt["model"], strict=False)
# if cfg.student.merge_block_indexes:
# merge_blocks_ind = cfg.student.merge_block_indexes
print("Before IF")
if cfg.student.merge_block_indexes:
self.merge_blocks_ind = cfg.student.merge_block_indexes
self.embed_dim = embed_dim
self.dino_out_dim = cfg.dino.head_n_prototypes
@@ -115,6 +156,7 @@ class SSLMetaArch(nn.Module):
self.need_to_synchronize_fsdp_streams = True
print("Student model dict: ", student_model_dict)
self.student = nn.ModuleDict(student_model_dict)
self.teacher = nn.ModuleDict(teacher_model_dict)
@@ -392,12 +434,216 @@ class SSLMetaArch(nn.Module):
def prepare_for_distributed_training(self):
logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
if has_batchnorms(self.student):
raise NotImplementedError
# if has_batchnorms(self.student):
# raise NotImplementedError
# below will synchronize all student subnetworks across gpus:
# print("Self.student: ", self.student.keys())
# for name, param in self.named_parameters():
# # Unfreeze only pre_encoder, model_adapter, and merge_blocks
# if not any(sub in name for sub in ["pre_encoder", "model_adapter", "merge_blocks"]):
# param.requires_grad = False
# else:
# param.requires_grad = True
print("Student keys: ", self.student.items())
for k, v in self.student.items():
self.teacher[k].load_state_dict(self.student[k].state_dict())
student_model_cfg = self.cfg.compute_precision.student[k]
print("Cfg: ", student_model_cfg)
self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
teacher_model_cfg = self.cfg.compute_precision.teacher[k]
self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
print("Type: ", type(self.student["backbone"].pre_encoder))
# if self.student["backbone"].pre_encoder:
# cfg = self.cfg.compute_precision.student.pre_encoder
# self.student["backbone"].pre_encoder = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder)
cfg = self.cfg.compute_precision.student.pre_encoder
# if not isinstance(self.student["backbone"].pre_encoder.Predictor_K, FSDP):
# self.student["backbone"].pre_encoder.Predictor_K = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Linear, nn.Conv2d, nn.LayerNorm})(self.student["backbone"].pre_encoder.Predictor_K )
import dinov2.distributed as distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
self.student["backbone"].pre_encoder.Predictor_K = FSDP(
self.student["backbone"].pre_encoder.Predictor_K,
sharding_strategy=cfg.sharding_strategy,
mixed_precision=cfg.mixed_precision,
device_id=distributed.get_local_rank(),
sync_module_states=True,
use_orig_params=True,
)
self.teacher["backbone"].pre_encoder.Predictor_K = FSDP(
self.teacher["backbone"].pre_encoder.Predictor_K,
sharding_strategy=cfg.sharding_strategy,
mixed_precision=cfg.mixed_precision,
device_id=distributed.get_local_rank(),
sync_module_states=True,
use_orig_params=True,
)
# if not isinstance(self.student["backbone"].pre_encoder.Predictor_M, FSDP):
# self.student["backbone"].pre_encoder.Predictor_M = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder.Predictor_M )
self.student["backbone"].pre_encoder.Predictor_M = FSDP(
self.student["backbone"].pre_encoder.Predictor_M,
sharding_strategy=cfg.sharding_strategy,
mixed_precision=cfg.mixed_precision,
device_id=distributed.get_local_rank(),
sync_module_states=True,
use_orig_params=True,
)
self.teacher["backbone"].pre_encoder.Predictor_M = FSDP(
self.teacher["backbone"].pre_encoder.Predictor_M,
sharding_strategy=cfg.sharding_strategy,
mixed_precision=cfg.mixed_precision,
device_id=distributed.get_local_rank(),
sync_module_states=True,
use_orig_params=True,
)
# if not isinstance(self.student["backbone"].pre_encoder.LUT, FSDP):
self.student["backbone"].pre_encoder.LUT = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder.LUT)
self.teacher["backbone"].pre_encoder.LUT = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.teacher["backbone"].pre_encoder.LUT)
cfg = self.cfg.compute_precision.student.model_adapter
self.student["backbone"].model_adapter = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].model_adapter)
self.teacher["backbone"].model_adapter = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.teacher["backbone"].model_adapter)
cfg = self.cfg.compute_precision.student.merge_blocks
for block in range(len(self.student["backbone"].merge_blocks)):
self.student["backbone"].merge_blocks[block] = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].merge_blocks[block])
# def prepare_for_distributed_training(self):
# logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
# # Synchronize all student subnetworks across GPUs
# for k, v in self.student.items():
# self.teacher[k].load_state_dict(self.student[k].state_dict())
# student_model_cfg = self.cfg.compute_precision.student[k]
# self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
# teacher_model_cfg = self.cfg.compute_precision.teacher[k]
# self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
# # Wrap the pre-encoder, model adapter, and merge blocks separately
# if hasattr(self, 'pre_encoder'):
# logger.info("Wrapping pre-encoder for FSDP")
# self.pre_encoder = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(self.pre_encoder)
# if hasattr(self, 'model_adapter'):
# logger.info("Wrapping model adapter for FSDP")
# self.model_adapter = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(self.model_adapter)
# if hasattr(self, 'merge_blocks'):
# logger.info("Wrapping merge blocks for FSDP")
# for i, block in enumerate(self.merge_blocks):
# self.merge_blocks[i] = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(block)
# def prepare_for_distributed_training(self):
# """
# Prepare both student and teacher models for distributed training using FSDP.
# This handles the issue of having both trainable and non-trainable parameters.
# """
# from torch.distributed.fsdp.wrap import ModuleWrapPolicy
# import copy
# # Process student models
# for k in self.student.keys():
# print("Preparing student model for distributed training:", k)
# # First make all parameters trainable to satisfy FSDP requirements
# for param in self.student[k].parameters():
# param.requires_grad_(True)
# # Wrap with FSDP
# student_model_cfg = self.cfg.compute_precision.student[k]
# print("Student __: ", self.student[k], student_model_cfg)
# self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
# # After FSDP wrapping, set appropriate parameters to not require gradient
# for name, param in self.student[k].named_parameters():
# # Freeze DINO backbone parameters
# if any([x in name for x in [
# 'pre_encoder', 'model_adapter', 'merge_blocks'
# ]]):
# # print("HERE HERE", name)
# param.requires_grad_(True)
# else:
# print("Name: ", name)
# param.requires_grad_(False)
# # print("Second option")
# # Print statistics about trainable parameters
# trainable_params = 0
# total_params = 0
# for name, param in self.student[k].named_parameters():
# total_params += param.numel()
# if param.requires_grad:
# trainable_params += param.numel()
# print(f"Student {k}: Total params: {total_params:,}, Trainable params: {trainable_params:,} ({trainable_params/total_params:.2%})")
# # Process teacher models
# for k in self.teacher.keys():
# print("Preparing teacher model for distributed training:", k)
# # Teacher model doesn't need gradient computation during training
# # But we first make all parameters trainable for FSDP wrapping
# for param in self.teacher[k].parameters():
# param.requires_grad_(True)
# # Wrap with FSDP
# teacher_model_cfg = self.cfg.compute_precision.teacher[k]
# print("Teacher __: ", self.teacher[k], teacher_model_cfg)
# self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
# # After FSDP wrapping, set all parameters to not require gradient
# # Teacher models are typically frozen and only updated via EMA
# for param in self.teacher[k].parameters():
# param.requires_grad_(False)
# # Print statistics
# total_params = sum(p.numel() for p in self.teacher[k].parameters())
# print(f"Teacher {k}: Total params: {total_params:,}, All parameters frozen")
# def prepare_for_distributed_training(self):
# logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
# # Before FSDP wrapping, ensure uniform requires_grad within each BlockChunk
# for k, v in self.student.items():
# for name, module in v.named_modules():
# if isinstance(module, BlockChunk):
# # # Make requires_grad uniform within each BlockChunk
# # # Option 1: Make all parameters in the BlockChunk trainable if any are trainable
# # any_trainable = any(p.requires_grad for p in module.parameters())
# # if any_trainable:
# # for param in module.parameters():
# # param.requires_grad = True
# # Option 2: Or alternatively, make all parameters in the BlockChunk non-trainable
# for param in module.parameters():
# param.requires_grad = False
# # Now proceed with FSDP wrapping
# for k, v in self.student.items():
# self.teacher[k].load_state_dict(self.student[k].state_dict())
# student_model_cfg = self.cfg.compute_precision.student[k]
# self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
# teacher_model_cfg = self.cfg.compute_precision.teacher[k]
# self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
def freeze_original_dino_weights(self):
print("Freezing")
# Freeze the original DINO weights
for param in self.parameters():
param.requires_grad = False
# Unfreeze the pre-encoder, model adapter, and merge blocks
if hasattr(self, 'pre_encoder'):
for param in self.pre_encoder.parameters():
param.requires_grad = True
if hasattr(self, 'model_adapter'):
for param in self.model_adapter.parameters():
param.requires_grad = True
if hasattr(self, 'merge_blocks'):
for block in self.merge_blocks:
for param in block.parameters():
param.requires_grad = True
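
After a freeze pass like freeze_original_dino_weights (or freeze_dino_weights on the backbone), it is worth confirming that only the adapter parameters remain trainable before the optimizer is built. A small sketch of that check, similar in spirit to the commented-out statistics code above:

def report_trainable(model):
    # model is any nn.Module; prints how much of it will actually receive gradients.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable: {trainable:,} / {total:,} ({trainable / max(total, 1):.2%})")
    for name, p in model.named_parameters():
        if p.requires_grad:
            print("  still trainable:", name)
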

View File

@@ -12,8 +12,7 @@ from functools import partial
from fvcore.common.checkpoint import PeriodicCheckpointer
import torch
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger
@@ -118,6 +117,36 @@ def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
param_group["weight_decay"] = wd * wd_multiplier
param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
def do_test(cfg, model, iteration):
# Ensure we are on the main process before saving
if distributed.is_main_process():
iterstring = str(iteration)
eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
os.makedirs(eval_dir, exist_ok=True)
# Define full state dict config
# full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
# # Set FSDP state_dict type before calling state_dict()
# with FSDP.state_dict_type(model.teacher, StateDictType.FULL_STATE_DICT, full_state_dict_config):
# new_state_dict = model.teacher.state_dict()
from torch.distributed.fsdp import FullStateDictConfig, StateDictType, state_dict_type
full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
checkpointer = FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
# Save teacher checkpoint
# teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
# torch.save({"teacher": new_state_dict}, teacher_ckp_path)
def do_test(cfg, model, iteration):
new_state_dict = model.teacher.state_dict()
@@ -133,15 +162,208 @@ def do_test(cfg, model, iteration):
from torch.utils.tensorboard import SummaryWriter
# def do_train(cfg, model, resume=False):
# from dinov2.data import SamplerType, make_data_loader, make_dataset
# from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
# writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/with_merged_blocks_small_lr")
# model.train()
# inputs_dtype = torch.half
# fp16_scaler = model.fp16_scaler # for mixed precision training
# # setup optimizer
# print("Cfg optim: ", cfg.optim)
# optimizer = build_optimizer(cfg, model.get_params_groups())
# # optimizer = build_optimizer(cfg, [p for p in model.parameters() if p.requires_grad])
# # trainable_params = [p for p in model.parameters() if p.requires_grad]
# # optimizer = torch.optim.AdamW(trainable_params, lr=lr)
# (
# lr_schedule,
# wd_schedule,
# momentum_schedule,
# teacher_temp_schedule,
# last_layer_lr_schedule,
# ) = build_schedulers(cfg)
# # checkpointer
# checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)
# start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
# OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
# max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
# periodic_checkpointer = PeriodicCheckpointer(
# checkpointer,
# period=3 * OFFICIAL_EPOCH_LENGTH,
# max_iter=max_iter,
# max_to_keep=3,
# )
# # setup data preprocessing
# img_size = cfg.crops.global_crops_size
# patch_size = cfg.student.patch_size
# n_tokens = (img_size // patch_size) ** 2
# mask_generator = MaskingGenerator(
# input_size=(img_size // patch_size, img_size // patch_size),
# max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
# )
# data_transform = DataAugmentationDINO(
# cfg.crops.global_crops_scale,
# cfg.crops.local_crops_scale,
# cfg.crops.local_crops_number,
# global_crops_size=cfg.crops.global_crops_size,
# local_crops_size=cfg.crops.local_crops_size,
# )
# collate_fn = partial(
# collate_data_and_cast,
# mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
# mask_probability=cfg.ibot.mask_sample_probability,
# n_tokens=n_tokens,
# mask_generator=mask_generator,
# dtype=inputs_dtype,
# )
# # setup data loader
# dataset = make_dataset(
# dataset_str=cfg.train.dataset_path,
# transform=data_transform,
# target_transform=lambda _: (),
# )
# # sampler_type = SamplerType.INFINITE
# sampler_type = SamplerType.SHARDED_INFINITE
# data_loader = make_data_loader(
# dataset=dataset,
# batch_size=cfg.train.batch_size_per_gpu,
# num_workers=cfg.train.num_workers,
# shuffle=True,
# seed=start_iter, # TODO: Fix this -- cfg.train.seed
# sampler_type=sampler_type,
# sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
# drop_last=True,
# collate_fn=collate_fn,
# )
# # training loop
# iteration = start_iter
# logger.info("Starting training from iteration {}".format(start_iter))
# metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
# metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
# header = "Training"
# for data in metric_logger.log_every(
# data_loader,
# 10,
# header,
# max_iter,
# start_iter,
# ):
# current_batch_size = data["collated_global_crops"].shape[0] / 2
# if iteration > max_iter:
# return
# # apply schedules
# lr = lr_schedule[iteration]
# wd = wd_schedule[iteration]
# mom = momentum_schedule[iteration]
# teacher_temp = teacher_temp_schedule[iteration]
# last_layer_lr = last_layer_lr_schedule[iteration]
# apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
# # compute losses
# optimizer.zero_grad(set_to_none=True)
# loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
# # clip gradients
# if fp16_scaler is not None:
# if cfg.optim.clip_grad:
# fp16_scaler.unscale_(optimizer)
# for v in model.student.values():
# v.clip_grad_norm_(cfg.optim.clip_grad)
# fp16_scaler.step(optimizer)
# fp16_scaler.update()
# else:
# if cfg.optim.clip_grad:
# for v in model.student.values():
# v.clip_grad_norm_(cfg.optim.clip_grad)
# optimizer.step()
# # perform teacher EMA update
# model.update_teacher(mom)
# # logging
# if distributed.get_global_size() > 1:
# for v in loss_dict.values():
# torch.distributed.all_reduce(v)
# loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}
# if math.isnan(sum(loss_dict_reduced.values())):
# logger.info("NaN detected")
# raise AssertionError
# losses_reduced = sum(loss for loss in loss_dict_reduced.values())
# metric_logger.update(lr=lr)
# metric_logger.update(wd=wd)
# metric_logger.update(mom=mom)
# metric_logger.update(last_layer_lr=last_layer_lr)
# metric_logger.update(current_batch_size=current_batch_size)
# metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
# writer.add_scalar('Loss/Total_Loss', losses_reduced, iteration)
# writer.add_scalar('Learning_Rate', lr, iteration)
# writer.add_scalar('Weight_Decay', wd, iteration)
# writer.add_scalar('Momentum', mom, iteration)
# # checkpointing and testing
# if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
# do_test(cfg, model, f"training_{iteration}")
# torch.cuda.synchronize()
# periodic_checkpointer.step(iteration)
# iteration = iteration + 1
# metric_logger.synchronize_between_processes()
# writer.close()
# return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def do_train(cfg, model, resume=False):
writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/standart_custom_conf")
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/with_merged_blocks_small_lr")
model.train()
inputs_dtype = torch.half
fp16_scaler = model.fp16_scaler # for mixed precision training
# setup optimizer
# # Freeze the original DINO weights
# for param in model.parameters():
# param.requires_grad = False
# # Unfreeze the pre-encoder, model adapter, and merge blocks
# for param in model.pre_encoder.parameters():
# param.requires_grad = True
# for param in model.model_adapter.parameters():
# param.requires_grad = True
# for block in model.merge_blocks:
# for param in block.parameters():
# param.requires_grad = True
# setup optimizer
print("Cfg optim: ", cfg.optim)
optimizer = build_optimizer(cfg, model.get_params_groups())
# optimizer = build_optimizer(cfg, [p for p in model.parameters() if p.requires_grad])
# trainable_params = [p for p in model.parameters() if p.requires_grad]
# optimizer = torch.optim.AdamW(trainable_params, lr=lr)
(
lr_schedule,
wd_schedule,
@@ -221,6 +443,7 @@ def do_train(cfg, model, resume=False):
metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
header = "Training"
warmup_iterations = cfg.optim.warmup_epochs * OFFICIAL_EPOCH_LENGTH
for data in metric_logger.log_every(
data_loader,
@@ -233,6 +456,12 @@ def do_train(cfg, model, resume=False):
if iteration > max_iter:
return
# Unfreeze original DINO weights after warmup
if iteration == warmup_iterations:
for param in model.parameters():
param.requires_grad = True
# apply schedules
lr = lr_schedule[iteration]
@@ -263,7 +492,6 @@ def do_train(cfg, model, resume=False):
optimizer.step()
# perform teacher EMA update
model.update_teacher(mom)
# logging
@@ -315,16 +543,17 @@ def parse_merge_block_indexes(config_value: str) -> List[int]:
return []
return list(map(int, re.split(r'\s*,\s*', config_value)))
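
parse_merge_block_indexes turns the comma-separated string from the student config (merge_block_indexes: "" above) into a list of block indices. A quick sketch of the expected behaviour, assuming the function as defined above:

# "0, 6, 10" is the kind of value the YAML comment "# num, num, num," suggests.
assert parse_merge_block_indexes("0, 6, 10") == [0, 6, 10]
# The empty default presumably hits the early return [] branch above (the guard is not shown in this hunk).
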
def main(args):
cfg = setup(args)
print("cfg.student.merge_block_indexes ", cfg.student.merge_block_indexes)
cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes)
print("cfg.student.merge_block_indexes:", cfg.student.merge_block_indexes)
# cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes)
# print("INDEXES", cfg.student.merge_block_indexes)
model = SSLMetaArch(cfg).to(torch.device("cuda"))
model.prepare_for_distributed_training()
# logger.info("Model:\n{}".format(model))
# model.freeze_original_dino_weights()
if args.eval_only:
iteration = (
FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
@@ -333,7 +562,7 @@ def main(args):
+ 1
)
return do_test(cfg, model, f"manual_{iteration}")
do_train(cfg, model, resume=not args.no_resume)

View File

@@ -18,17 +18,26 @@ logger = logging.getLogger("dinov2")
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
if urlparse(pretrained_weights).scheme: # If it looks like an URL
state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
else:
state_dict = torch.load(pretrained_weights, map_location="cpu")
if checkpoint_key is not None and checkpoint_key in state_dict:
state_dict = state_dict['model']
if checkpoint_key is not None:
print("INSIDE IF")
logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
state_dict = state_dict[checkpoint_key]
state_dict = {key: value for key, value in state_dict.items() if key.startswith(checkpoint_key) }
# state_dict = state_dict[checkpoint_key]
# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("teacher.backbone.", ""): v for k, v in state_dict.items()}
# print("Model Keys:")
# for key in state_dict.keys():
# print(key)
# print("State dict: ", state_dict.keys())
msg = model.load_state_dict(state_dict, strict=False)
logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
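
The key renaming above strips wrapper prefixes so a teacher checkpoint can be loaded straight into the backbone with strict=False. A compact sketch of the same prefix-normalisation idea on a generic state dict (prefix names are illustrative):

def strip_prefixes(state_dict, prefixes=("module.", "backbone.", "teacher.backbone.")):
    # Drop a leading wrapper prefix from each key, if present.
    out = {}
    for k, v in state_dict.items():
        for p in prefixes:
            if k.startswith(p):
                k = k[len(p):]
        out[k] = v
    return out

# msg = model.load_state_dict(strip_prefixes(checkpoint["teacher"]), strict=False)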