fsdp for raw adapter

parent 63dfb318ec
commit 753b990c64
@@ -5,38 +5,76 @@ compute_precision:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
  student:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    merge_blocks:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    pre_encoder:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
        cast_forward_inputs: False
    model_adapter:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
  student:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    merge_blocks:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
    pre_encoder:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
        cast_forward_inputs: False
    model_adapter:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp32
        reduce_dtype: fp32
        buffer_dtype: fp32
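Minimal sketch of how one of the per-module entries above can be turned into a torch MixedPrecision policy; it mirrors the dtype_dict used in get_fsdp_wrapper later in this commit. The module_cfg dict is an illustrative stand-in for the config node (e.g. compute_precision.student.pre_encoder.mixed_precision), not the real loader.

    import torch
    from torch.distributed.fsdp import MixedPrecision

    dtype_dict = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}

    # Stand-in for one mixed_precision node from the YAML above.
    module_cfg = {"param_dtype": "fp32", "reduce_dtype": "fp32", "buffer_dtype": "fp32",
                  "cast_forward_inputs": False}

    policy = MixedPrecision(
        param_dtype=dtype_dict[module_cfg["param_dtype"]],
        reduce_dtype=dtype_dict[module_cfg["reduce_dtype"]],
        buffer_dtype=dtype_dict[module_cfg["buffer_dtype"]],
        cast_forward_inputs=module_cfg.get("cast_forward_inputs", True),
    )
    print(policy)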
@@ -48,8 +86,8 @@ data_transform: "default"
train:
  batch_size_per_gpu: 16 #vitg 26+, vitl: 56, vits:152, vitb:120 for 8 node
  num_workers: 1
  OFFICIAL_EPOCH_LENGTH: 100 # 1250
  dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/
  OFFICIAL_EPOCH_LENGTH: 1000 # 1250
  dataset_path: ImageNet:root=/home/paperspace/Documents/nika_space/ADE20K/ADEChallengeData2016/images/training_raw/ #MIT:root=/home/paperspace/Documents/nika_space/mit_dataset/train/
  centering: sinkhorn_knopp

  drop_path_rate: 0.4

@@ -61,8 +99,8 @@ train:
teacher:
  momentum_teacher: 0.994
optim:
  epochs: 20 # 500
  weight_decay_end: 0.2
  epochs: 50 # 500
  weight_decay_end: 0.3
  base_lr: 0.0001 # learning rate for a batch size of 1024
  warmup_epochs: 20 # 80
  layerwise_decay: 1.0

@@ -76,7 +114,7 @@ evaluation:
# "dinov2_vits14","dinov2_vitb14","dinov2_vitl14","dinov2_vitg14"

student:
  arch: vit_base
  arch: dinov2_vitb14
  patch_size: 14
  merge_block_indexes: "" # num, num, num,
crops:
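Sketch of how the string-valued merge_block_indexes field above can be read and normalized to a list of ints. The parsing mirrors the parse_merge_block_indexes helper further down in this diff; the inline YAML is a trimmed stand-in for the real config file.

    import re
    from omegaconf import OmegaConf

    cfg = OmegaConf.create("""
    student:
      arch: dinov2_vitb14
      patch_size: 14
      merge_block_indexes: "0, 6, 10"
    """)

    def parse_merge_block_indexes(value: str):
        # An empty string (the default above) means "no merge blocks".
        if not value:
            return []
        return list(map(int, re.split(r"\s*,\s*", value)))

    print(parse_merge_block_indexes(cfg.student.merge_block_indexes))  # [0, 6, 10]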
@@ -17,6 +17,7 @@ from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp._runtime_utils import _reshard
import torch.nn as nn


def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):

@@ -32,16 +33,22 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
        "bf16": torch.bfloat16,
    }

    param_dtype = dtype_dict[model_cfg.mixed_precision.param_dtype]
    reduce_dtype = dtype_dict[model_cfg.mixed_precision.reduce_dtype]
    buffer_dtype = dtype_dict[model_cfg.mixed_precision.buffer_dtype]

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
        reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
        buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        buffer_dtype=buffer_dtype,
        cast_forward_inputs=model_cfg.mixed_precision.get("cast_forward_inputs", True)
    )

    sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]

    local_rank = distributed.get_local_rank()

    print("Modules to wrap: ", modules_to_wrap)
    fsdp_wrapper = partial(
        FSDP,
        sharding_strategy=sharding_strategy_config,

@@ -51,6 +58,7 @@ def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
        use_orig_params=True,
        auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
    )

    return fsdp_wrapper
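Usage sketch for the wrapper factory above, assuming a config node with the same fields as the YAML at the top of this commit. The tiny adapter module is a placeholder; in the real code the wrapped targets are BlockChunk, the pre-encoder sub-predictors, the model adapter, and the merge blocks.

    import torch.nn as nn
    from omegaconf import OmegaConf

    adapter_cfg = OmegaConf.create({
        "sharding_strategy": "SHARD_GRAD_OP",
        "mixed_precision": {"param_dtype": "fp32", "reduce_dtype": "fp32",
                            "buffer_dtype": "fp32", "cast_forward_inputs": False},
    })

    toy_adapter = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))

    # get_fsdp_wrapper returns a partial over FSDP, so it is applied like a constructor.
    # Running the wrapped call requires an initialized torch.distributed process group
    # and a CUDA device, exactly as in the training code, hence it is left commented:
    # wrapped = get_fsdp_wrapper(adapter_cfg, modules_to_wrap={nn.Linear})(toy_adapter)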
@@ -112,7 +120,7 @@ class FSDPCheckpointer(Checkpointer):
        self.tag_last_checkpoint(basename)

    def load(self, *args, **kwargs):
        with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
            return super().load(*args, **kwargs)

    def has_checkpoint(self) -> bool:
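Sketch of the two state-dict modes toggled in the change above: LOCAL_STATE_DICT loads and saves per-rank shards, while FULL_STATE_DICT expects a single consolidated checkpoint, which is how the adapter checkpoints here are stored. The fsdp_model below is assumed to already be FSDP-wrapped inside an initialized process group.

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    def load_full_checkpoint(fsdp_model, path):
        # Optional config flags shown for the common case of a CPU-side full checkpoint.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
        state = torch.load(path, map_location="cpu")
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            fsdp_model.load_state_dict(state)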
@@ -28,6 +28,7 @@ def _make_dinov2_model(
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    merge_block_indexes: list[int] = [0],
    **kwargs,
):
    from ..models import vision_transformer as vits

@@ -48,8 +49,10 @@ def _make_dinov2_model(
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
        merge_block_indexes=merge_block_indexes
    )
    vit_kwargs.update(**kwargs)
    print("Kwargs: ", vit_kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
@@ -26,8 +26,10 @@ def build_model(args, only_teacher=False, img_size=224):
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
        # merge_blocks_indexes=args.merge_block_indexes,
        merge_block_indexes=args.student.merge_block_indexes,
    )
    print( vits.__dict__.keys())
    print("Merge block ind: ", args.student.merge_block_indexes)
    teacher = vits.__dict__[args.arch](**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim

@@ -41,4 +43,5 @@ def build_model(args, only_teacher=False, img_size=224):


def build_model_from_cfg(cfg, only_teacher=False):
    print("Only teacher", only_teacher)
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
@@ -25,28 +25,30 @@ class Merge_block(BaseModule):
        self.ada_c = ada_c
        # 784 - embedded dim + adapter_c
        self.embeded_dim = 768
        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c).to(torch.float16)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim).to(torch.float16)
        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c)
        print("Fc 1 type: ", self.fc_1.weight.dtype, self.fc_1.bias.dtype)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
        self.return_ada = return_ada

        if self.return_ada:
            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1).to(torch.float16)  # 1D Conv instead of 3x3
            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1)  # 1D Conv instead of 3x3
        else:
            self.conv_3 = None

    def forward(self, fea, adapter, ratio=1.0):
        res = fea
        # print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)
        # print("before concatenation: ", fea.shape, adapter.shape)
        # print("before concatenation: ", fea.dtype, adapter.dtype)
        fea = torch.cat([fea, adapter], dim=-1)  # (B, seq_len, fea_c + ada_c)
        # print("after concatenation: ", fea.shape, adapter.shape)
        B, seq_len, C = fea.shape
        fea = fea.view(B * seq_len, C)
        # print("before concatenation: ", fea.dtype, adapter.dtype)
        fea = fea.to(self.fc_1.weight.dtype)
        fea = self.fc_1(fea)
        fea = fea.view(B, seq_len, -1)
        ada = self.fc_2(fea)
        fea_out = ratio * ada + res

        if self.return_ada:

            ada = self.conv_3(fea.permute(0, 2, 1))
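The hunk above drops the hard .to(torch.float16) casts on the merge-block layers and instead casts the activations to the weight dtype at call time, which keeps the block usable under the fp32 FSDP mixed-precision policy configured earlier. A self-contained sketch of that pattern (dimensions are illustrative, not the real fea_c/ada_c split):

    import torch
    import torch.nn as nn

    class TinyMerge(nn.Module):
        def __init__(self, embed_dim=768, ada_c=16, mid_c=384):
            super().__init__()
            self.fc_1 = nn.Linear(embed_dim + ada_c, mid_c)
            self.fc_2 = nn.Linear(mid_c, embed_dim)

        def forward(self, fea, adapter, ratio=1.0):
            res = fea
            fused = torch.cat([fea, adapter], dim=-1)
            fused = fused.to(self.fc_1.weight.dtype)  # align activation dtype with weights
            out = self.fc_2(self.fc_1(fused))
            return ratio * out.to(res.dtype) + res     # residual keeps the input dtype

    x = torch.randn(2, 257, 768, dtype=torch.float16)
    a = torch.randn(2, 257, 16, dtype=torch.float16)
    print(TinyMerge()(x, a).dtype)  # torch.float16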
@@ -313,7 +315,7 @@ class CustomLayerNorm(nn.Module):

# Predictor P_K
class Kernel_Predictor(nn.Module):
    def __init__(self, dim, mode='low', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
    def __init__(self, dim, mode='normal', num_heads=1, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

@@ -546,7 +548,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
    return img


def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=3):  # [9, 9] in LOD dataset, [3, 3] in other dataset
def Gain_Denoise(I1, r1, r2, gain, sigma, k_size=1):  # [9, 9] in LOD dataset, [3, 3] in other dataset
    out = []
    for i in range(I1.shape[0]):
        I1_gain = gain[i] * I1[i,:,:,:]

@@ -583,11 +585,11 @@ def WB_CCM(I2, ccm_matrix, distance):
    out_I4 = []
    for i in range(I2.shape[0]):
        # SOG White Balance Algorithm
        I3 = SoG_algo(I2[i,:,:,:], distance[i])
        I3 = SoG_algo(I2[i,:,:,:])

        # Camera Color Matrix
        I4 = torch.tensordot(I3, ccm_matrix[i,:,:], dims=[[-1], [-1]])
        I4 = torch.clamp(I4, 1e-5, 1.0)
        I4 = torch.clamp(I4, 1e-7, 1.0)

        out_I3.append(I3)
        out_I4.append(I4)

@@ -620,7 +622,7 @@ class VitInputLevelAdapter(nn.Module):
        # (1). I1 --> I2: Denoise & Enhancement & Sharpen
        r1, r2, gain, sigma = self.Predictor_K(I1)
        I2 = Gain_Denoise(I1, r1, r2, gain, sigma, k_size=self.k_size)  # (B,C,H,W)
        I2 = torch.clamp(I2, 1e-5, 1.0)  # normal & over-exposure
        I2 = torch.clamp(I2, 1e-7, 1.0)  # normal & over-exposure

        ccm_matrix, distance = self.Predictor_M(I2)
        # (2). I2 --> I3: White Balance, Shade of Gray
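WB_CCM above calls SoG_algo, whose body is not part of this hunk (only its signature changes: the distance argument is dropped). A hedged sketch of the I2 -> I3 -> I4 path using the textbook Shades-of-Gray formulation, which is an assumption about SoG_algo rather than the repository's exact implementation, followed by the same tensordot color-matrix multiply and the new 1e-7 clamp floor:

    import torch

    def sog_white_balance(img, p=6, eps=1e-7):
        # img: (H, W, 3) in [0, 1]; estimate the illuminant with a Minkowski p-norm mean.
        illum = (img.clamp_min(eps) ** p).mean(dim=(0, 1)) ** (1.0 / p)
        illum = illum / illum.norm()
        gain = (1.0 / (3 ** 0.5)) / illum.clamp_min(eps)
        return (img * gain).clamp(0.0, 1.0)

    H, W = 8, 8
    I2 = torch.rand(H, W, 3)
    ccm = torch.eye(3) + 0.05 * torch.randn(3, 3)        # illustrative camera color matrix
    I3 = sog_white_balance(I2)
    I4 = torch.tensordot(I3, ccm, dims=[[-1], [-1]])     # same contraction as in WB_CCM
    I4 = torch.clamp(I4, 1e-7, 1.0)                      # new clamp floor from this commit
    print(I3.shape, I4.shape)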
@@ -19,7 +19,7 @@ from torch.nn.init import trunc_normal_
import torch.nn.functional as F
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from dinov2.models.help import Merge_block, Model_level_Adapeter
from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter
from dinov2.models.input_level_adapter import Input_level_Adapeter


logger = logging.getLogger("dinov2")

@@ -68,17 +68,18 @@ class DinoVisionTransformer(nn.Module):
        interpolate_antialias=False,
        interpolate_offset=0.1,
        # RAW adapter parameters
        w_lut=False,
        w_lut=True,
        light_mode='normal',
        lut_dim=32,
        k_size=3,
        merge_ratio=1.0,
        model_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_model_adapter_weights.pth',
        input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/extracted_pre_encoder_weights.pth',
        input_level_adapter_path='/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/mmsegmentation_github/extracted_pre_encoder_weights.pth',
        fea_c_s = [384, 768, 1920],
        ada_c_s = [16, 32, 64],
        mid_c_s = [384, 576, 768],
        # merge_blocks_indexes=[],
        merge_block_indexes=[],
        is_teacher=False,
    ):
        """
        Args:
@@ -190,28 +191,31 @@ class DinoVisionTransformer(nn.Module):
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        # Initialize RAW adapter
        # self.merge_block_indexes = merge_block_indexes
        self.merge_block_indexes = [0, 6, 10]
        print("Before if:")
        if self.w_lut:
            self.pre_encoder = Input_level_Adapeter(mode=light_mode, lut_dim=lut_dim, k_size=k_size, w_lut=w_lut)
            for param in self.pre_encoder.parameters():
                param.requires_grad_(True)

        # self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut)
        self.model_adapter = Model_level_Adapeter(in_c=in_chans, w_lut=w_lut)
        self.model_adapter = Model_level_Adapeter(in_c=3, in_dim=ada_c_s[0], w_lut=self.w_lut)
        # if model_adapter_path is not None:
        #     print("Loading model adapter:", model_adapter_path)
        #     adapter_state = torch.load(model_adapter_path, map_location="cpu")
        #     self.model_adapter.load_state_dict(adapter_state, strict=False)
        # if input_level_adapter_path is not None:
        #     print("Loading input-level adapter:", input_level_adapter_path)
        #     adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
        #     self.pre_encoder.load_state_dict(adapter_state)
        if model_adapter_path is not None:
            print("Loading model adapter:", model_adapter_path)
            adapter_state = torch.load(model_adapter_path, map_location="cpu")
            self.model_adapter.load_state_dict(adapter_state, strict=False)
        if input_level_adapter_path is not None:
            print("Loading input-level adapter:", input_level_adapter_path)
            adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
            self.pre_encoder.load_state_dict(adapter_state)

        self.merge_blocks = []
        self.merge_blocks_indexes = merge_blocks_indexes

        # Loop through the merge_blocks_indexes and create Merge_block instances
        for i, idx in enumerate(self.merge_blocks_indexes):
            return_ada = False if i == len(self.merge_blocks_indexes) - 1 else True  # Only the last block gets return_ada=False
            if i != 0 or i != len(self.merge_blocks_indexes) - 1:
        for i, idx in enumerate(self.merge_block_indexes):
            return_ada = False if i == len(self.merge_block_indexes) - 1 else True  # Only the last block gets return_ada=False
            if i != 0 or i != len(self.merge_block_indexes) - 1:
                k = 1
            else:
                k = i

@@ -221,12 +225,57 @@ class DinoVisionTransformer(nn.Module):
                mid_c=mid_c_s[k],
                return_ada=return_ada
            ).to("cuda")
            self.merge_blocks.append(merge_block)
        # self.merge_blocks.to("cuda")
        print(self.merge_blocks)

            # merge_block_state = torch.load("/home/paperspace/Documents/nika_space/ECCV_RAW_Adapter/mmsegmentation_github/extracted_merged_blocks_weights_2.pth", map_location="cpu")
            # merge_block.load_state_dict(merge_block_state)
            # if is_teacher:
            #     for param in merge_block.parameters():
            #         param.requires_grad = False

            self.merge_blocks.append(merge_block)
        # self.merge_blocks.to("cuda")
        print("MERGED BLOCKS:", self.merge_block_indexes)


        # # Freeze the patch embedding
        # for param in self.patch_embed.parameters():
        #     param.requires_grad_(False)

        # # Freeze the position embedding
        # if hasattr(self, 'pos_embed') and self.pos_embed is not None:
        #     self.pos_embed.requires_grad_(False)

        # # Freeze the cls token if it exists
        # if hasattr(self, 'cls_token') and self.cls_token is not None:
        #     self.cls_token.requires_grad_(False)

        # # Freeze the transformer blocks
        # for block in self.blocks:
        #     for param in block.parameters():
        #         param.requires_grad_(False)

        # # Freeze the norm layer if it exists
        # if hasattr(self, 'norm') and self.norm is not None:
        #     for param in self.norm.parameters():
        #         param.requires_grad_(False)

        # # Freeze any head/projection layers
        # if hasattr(self, 'head') and self.head is not None:
        #     for param in self.head.parameters():
        #         param.requires_grad_(False)

        self.init_weights()
        # self.freeze_dino_weights()


    def freeze_dino_weights(self):
        """Freeze all original DINO weights and keep only the adapter trainable."""
        for name, param in self.named_parameters():
            # Unfreeze only pre_encoder, model_adapter, and merge_blocks
            if not any(sub in name for sub in ["pre_encoder", "model_adapter", "merge_blocks"]):
                param.requires_grad = False
            else:
                param.requires_grad = True  # Ensure adapters are trainable

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
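Small sanity-check sketch for freeze_dino_weights above: count how many parameters stay trainable once everything outside pre_encoder / model_adapter / merge_blocks is frozen. The toy module is a stand-in; with the real DinoVisionTransformer the same two-line count reports the adapter-only trainable share.

    import torch.nn as nn

    def count_trainable(module: nn.Module):
        total = sum(p.numel() for p in module.parameters())
        trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        return trainable, total

    # Toy stand-in: freeze everything whose name does not contain an adapter keyword,
    # mirroring the name-based filter in freeze_dino_weights.
    toy = nn.ModuleDict({"blocks": nn.Linear(8, 8), "model_adapter": nn.Linear(8, 8)})
    for name, p in toy.named_parameters():
        p.requires_grad = any(k in name for k in ["pre_encoder", "model_adapter", "merge_blocks"])
    t, n = count_trainable(toy)
    print(f"trainable: {t:,} / {n:,} ({t / n:.2%})")  # half of the toy parameters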
@@ -284,33 +333,33 @@ class DinoVisionTransformer(nn.Module):

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        # print("BLOCKS NUM: " , len(self.blocks), len(self.merge_blocks))


        if self.w_lut:  # I1, I2, I3, I4
            x_raw = self.pre_encoder(x)
        if self.w_lut:  # I1, I2, I3, I4
            ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]])
        # else:  # I1, I2, I3
        #     ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]])
        else:  # I1, I2, I3
            ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2]])

        # x = x_raw[-1]
        x = x_raw[-1]
        # print("X before patch embedding ",ada.shape, x.shape )
        x = self.patch_embed(x)
        # if x.shape[1] == 256:
        #     ada = F.interpolate(ada, size=(64, 64), mode='bilinear', align_corners=False)
        # elif x.shape[1] == 49:
        #     ada = F.interpolate(ada, size=(28, 28), mode='bilinear', align_corners=False)
        if x.shape[1] == 256:
            ada = F.interpolate(ada, size=(64, 64), mode='bilinear', align_corners=False)
        elif x.shape[1] == 49:
            ada = F.interpolate(ada, size=(28, 28), mode='bilinear', align_corners=False)
        # print("ada.shape ", ada.shape, x.shape)
        # ada = self.patch_embed_for_model_adapter(ada)
        ada = self.patch_embed_for_model_adapter(ada)
        # tensor2_reshaped = ada.transpose(1, 2)  # [32, 768, 196]
        # print("ada.shape after embedding ", ada.shape, x.shape)

        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            # ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)
            ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        # ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
        # ada = ada + self.interpolate_pos_encoding(ada, w, h)
        ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
        ada = ada + self.interpolate_pos_encoding(ada, w, h)

        x = x + self.interpolate_pos_encoding(x, w, h)


@@ -328,8 +377,8 @@ class DinoVisionTransformer(nn.Module):
        # print("x.shape",x.shape)


        # return x, ada
        return x, None
        return x, ada
        # return x, None

    def forward_features_list(self, x_list, masks_list):
@@ -349,7 +398,8 @@ class DinoVisionTransformer(nn.Module):
            x = blk(x)


            if self.w_lut and ada is not None and i in self.merge_blocks_indexes:
            if self.w_lut and ada is not None and i in self.merge_block_indexes:
                # print("ERROR!")
                x_ada_pairs = [self.merge_blocks[indx](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
                x, ada_list = map(list, zip(*x_ada_pairs))
                indx += 1

@@ -376,9 +426,11 @@ class DinoVisionTransformer(nn.Module):
        x, ada = self.prepare_tokens_with_masks(x, masks)
        indx = 0
        for i, blk in enumerate(self.blocks):
            # print("X type: ", x.dtype)
            x = blk(x)
            if self.w_lut and ada is not None and i in self.merge_blocks_indexes:
            if self.w_lut and ada is not None and i in self.merge_block_indexes:
                # print("HERE 11", x.shape, ada.shape)
                # print("ERROR!")
                x, ada = self.merge_blocks[indx](x, ada, ratio=self.merge_ratio)
                indx += 1


@@ -392,7 +444,7 @@ class DinoVisionTransformer(nn.Module):
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        x, ada = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n

@@ -404,7 +456,7 @@ class DinoVisionTransformer(nn.Module):
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        x, ada = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
@@ -26,6 +26,39 @@ except ImportError:


logger = logging.getLogger("dinov2")
import math


def interpolate_pos_encoding(x, w, h):
    N = x.shape[1] - 1
    dim = x.shape[-1]
    w0 = w / int(math.sqrt(N))
    h0 = h / int(math.sqrt(N))

    # Interpolate the position embeddings without changing the first row (class token)
    patch_pos_embed = nn.functional.interpolate(
        x[:, 1:].reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        scale_factor=(w0, h0),
        mode="bicubic",
    )

    # assert int(w0) == patch_pos_embed.shape[-2]
    # assert int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

    # Concatenate the class token with the interpolated position embeddings
    return torch.cat((x[:, :1], patch_pos_embed), dim=1)


def get_downloaded_dino_vit_interpolated(modelname="dinov2_vits14", merge_block_indexes=[], is_teacher=False):
    print("HERE !")
    model = torch.hub.load("facebookresearch/dinov2", modelname, pretrained=False, merge_block_indexes=merge_block_indexes, is_teacher=is_teacher)  #
    print("HERE 2")
    input_tensor = model.pos_embed
    input_tensor = input_tensor.to('cuda')
    tensor_corr_shape = interpolate_pos_encoding(input_tensor, 16, 16)
    pos_embed = nn.Parameter(torch.zeros(1, 257))
    pos_embed.data = tensor_corr_shape
    model.pos_embed = pos_embed
    return model
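Quick shape check for interpolate_pos_encoding above: a pos_embed with a class token plus a square patch grid is resized to the 16x16 grid used here (224 / 14 = 16 patches per side), giving 1 + 16*16 = 257 position embeddings, which matches the nn.Parameter(torch.zeros(1, 257)) created in get_downloaded_dino_vit_interpolated. The tensor below is random and uses a 32x32 source grid so the scale factor is exact; only the shapes matter.

    import math
    import torch
    import torch.nn as nn

    def interp_pos(x, w, h):
        N = x.shape[1] - 1
        dim = x.shape[-1]
        side = int(math.sqrt(N))
        grid = x[:, 1:].reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(grid, scale_factor=(w / side, h / side), mode="bicubic")
        grid = grid.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((x[:, :1], grid), dim=1)

    pos_embed = torch.randn(1, 1 + 32 * 32, 768)   # illustrative; pretrained dinov2 checkpoints use a 37x37 grid
    print(interp_pos(pos_embed, 16, 16).shape)     # torch.Size([1, 257, 768])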

class SSLMetaArch(nn.Module):
@@ -37,7 +70,14 @@ class SSLMetaArch(nn.Module):
        student_model_dict = dict()
        teacher_model_dict = dict()

        student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
        if cfg.student.arch in ["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"]:
            print("Load pre-trained encoder:")
            student_backbone = get_downloaded_dino_vit_interpolated(cfg.student.arch, cfg.student.merge_block_indexes)
            teacher_backbone = get_downloaded_dino_vit_interpolated(cfg.student.arch, cfg.student.merge_block_indexes, is_teacher=True)
            embed_dict = {"dinov2_vits14": 384, "dinov2_vitb14": 768, "dinov2_vitl14": 1024, "dinov2_vitg14": 1536}
            embed_dim = embed_dict[cfg.student.arch]
        else:
            student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
        student_model_dict["backbone"] = student_backbone
        teacher_model_dict["backbone"] = teacher_backbone
        logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")

@@ -47,8 +87,9 @@ class SSLMetaArch(nn.Module):
            logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
            student_backbone.load_state_dict(chkpt["model"], strict=False)

        # if cfg.student.merge_block_indexes:
        #     merge_blocks_ind = cfg.student.merge_block_indexes
        print("Before IF")
        if cfg.student.merge_block_indexes:
            self.merge_blocks_ind = cfg.student.merge_block_indexes

        self.embed_dim = embed_dim
        self.dino_out_dim = cfg.dino.head_n_prototypes

@@ -115,6 +156,7 @@ class SSLMetaArch(nn.Module):

        self.need_to_synchronize_fsdp_streams = True

        print("Student model dict: ", student_model_dict)
        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)
@@ -392,12 +434,216 @@ class SSLMetaArch(nn.Module):

    def prepare_for_distributed_training(self):
        logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
        if has_batchnorms(self.student):
            raise NotImplementedError
        # if has_batchnorms(self.student):
        #     raise NotImplementedError
        # below will synchronize all student subnetworks across gpus:
        # print("Self.student: ", self.student.keys())
        # for name, param in self.named_parameters():
        #     # Unfreeze only pre_encoder, model_adapter, and merge_blocks
        #     if not any(sub in name for sub in ["pre_encoder", "model_adapter", "merge_blocks"]):
        #         param.requires_grad = False
        #     else:
        #         param.requires_grad = True

        print("Student keys: ", self.student.items())
        for k, v in self.student.items():
            self.teacher[k].load_state_dict(self.student[k].state_dict())
            student_model_cfg = self.cfg.compute_precision.student[k]
            print("Cfg: ", student_model_cfg)

            self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
            teacher_model_cfg = self.cfg.compute_precision.teacher[k]
            self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])

        print("Type: ", type(self.student["backbone"].pre_encoder))
        # if self.student["backbone"].pre_encoder:
        #     cfg = self.cfg.compute_precision.student.pre_encoder
        #     self.student["backbone"].pre_encoder = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder)

        cfg = self.cfg.compute_precision.student.pre_encoder
        # if not isinstance(self.student["backbone"].pre_encoder.Predictor_K, FSDP):
        #     self.student["backbone"].pre_encoder.Predictor_K = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Linear, nn.Conv2d, nn.LayerNorm})(self.student["backbone"].pre_encoder.Predictor_K )
        import dinov2.distributed as distributed
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        self.student["backbone"].pre_encoder.Predictor_K = FSDP(
            self.student["backbone"].pre_encoder.Predictor_K,
            sharding_strategy=cfg.sharding_strategy,
            mixed_precision=cfg.mixed_precision,
            device_id=distributed.get_local_rank(),
            sync_module_states=True,
            use_orig_params=True,
        )
        self.teacher["backbone"].pre_encoder.Predictor_K = FSDP(
            self.teacher["backbone"].pre_encoder.Predictor_K,
            sharding_strategy=cfg.sharding_strategy,
            mixed_precision=cfg.mixed_precision,
            device_id=distributed.get_local_rank(),
            sync_module_states=True,
            use_orig_params=True,
        )
        # if not isinstance(self.student["backbone"].pre_encoder.Predictor_M, FSDP):
        #     self.student["backbone"].pre_encoder.Predictor_M = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder.Predictor_M )
        self.student["backbone"].pre_encoder.Predictor_M = FSDP(
            self.student["backbone"].pre_encoder.Predictor_M,
            sharding_strategy=cfg.sharding_strategy,
            mixed_precision=cfg.mixed_precision,
            device_id=distributed.get_local_rank(),
            sync_module_states=True,
            use_orig_params=True,
        )
        self.teacher["backbone"].pre_encoder.Predictor_M = FSDP(
            self.teacher["backbone"].pre_encoder.Predictor_M,
            sharding_strategy=cfg.sharding_strategy,
            mixed_precision=cfg.mixed_precision,
            device_id=distributed.get_local_rank(),
            sync_module_states=True,
            use_orig_params=True,
        )
        # if not isinstance(self.student["backbone"].pre_encoder.LUT, FSDP):
        self.student["backbone"].pre_encoder.LUT = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].pre_encoder.LUT)
        self.teacher["backbone"].pre_encoder.LUT = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.teacher["backbone"].pre_encoder.LUT)
        cfg = self.cfg.compute_precision.student.model_adapter
        self.student["backbone"].model_adapter = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].model_adapter)
        self.teacher["backbone"].model_adapter = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.teacher["backbone"].model_adapter)
        cfg = self.cfg.compute_precision.student.merge_blocks
        for block in range(len(self.student["backbone"].merge_blocks)):
            self.student["backbone"].merge_blocks[block] = get_fsdp_wrapper(cfg, modules_to_wrap={nn.Module})(self.student["backbone"].merge_blocks[block])


    # def prepare_for_distributed_training(self):
    #     logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")

    #     # Synchronize all student subnetworks across GPUs
    #     for k, v in self.student.items():
    #         self.teacher[k].load_state_dict(self.student[k].state_dict())
    #         student_model_cfg = self.cfg.compute_precision.student[k]
    #         self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
    #         teacher_model_cfg = self.cfg.compute_precision.teacher[k]
    #         self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])

    #     # Wrap the pre-encoder, model adapter, and merge blocks separately
    #     if hasattr(self, 'pre_encoder'):
    #         logger.info("Wrapping pre-encoder for FSDP")
    #         self.pre_encoder = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(self.pre_encoder)

    #     if hasattr(self, 'model_adapter'):
    #         logger.info("Wrapping model adapter for FSDP")
    #         self.model_adapter = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(self.model_adapter)

    #     if hasattr(self, 'merge_blocks'):
    #         logger.info("Wrapping merge blocks for FSDP")
    #         for i, block in enumerate(self.merge_blocks):
    #             self.merge_blocks[i] = get_fsdp_wrapper(self.cfg.compute_precision.student, modules_to_wrap={nn.Module})(block)

    # def prepare_for_distributed_training(self):
    #     """
    #     Prepare both student and teacher models for distributed training using FSDP.
    #     This handles the issue of having both trainable and non-trainable parameters.
    #     """
    #     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
    #     import copy

    #     # Process student models
    #     for k in self.student.keys():
    #         print("Preparing student model for distributed training:", k)

    #         # First make all parameters trainable to satisfy FSDP requirements
    #         for param in self.student[k].parameters():
    #             param.requires_grad_(True)

    #         # Wrap with FSDP
    #         student_model_cfg = self.cfg.compute_precision.student[k]
    #         print("Student __: ", self.student[k], student_model_cfg)
    #         self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])

    #         # After FSDP wrapping, set appropriate parameters to not require gradient
    #         for name, param in self.student[k].named_parameters():
    #             # Freeze DINO backbone parameters
    #             if any([x in name for x in [
    #                 'pre_encoder', 'model_adapter', 'merge_blocks'
    #             ]]):
    #                 # print("HERE HERE", name)
    #                 param.requires_grad_(True)
    #             else:
    #                 print("Name: ", name)
    #                 param.requires_grad_(False)
    #                 # print("Second option")

    #         # Print statistics about trainable parameters
    #         trainable_params = 0
    #         total_params = 0
    #         for name, param in self.student[k].named_parameters():
    #             total_params += param.numel()
    #             if param.requires_grad:
    #                 trainable_params += param.numel()

    #         print(f"Student {k}: Total params: {total_params:,}, Trainable params: {trainable_params:,} ({trainable_params/total_params:.2%})")

    #     # Process teacher models
    #     for k in self.teacher.keys():
    #         print("Preparing teacher model for distributed training:", k)

    #         # Teacher model doesn't need gradient computation during training
    #         # But we first make all parameters trainable for FSDP wrapping
    #         for param in self.teacher[k].parameters():
    #             param.requires_grad_(True)

    #         # Wrap with FSDP
    #         teacher_model_cfg = self.cfg.compute_precision.teacher[k]
    #         print("Teacher __: ", self.teacher[k], teacher_model_cfg)
    #         self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])

    #         # After FSDP wrapping, set all parameters to not require gradient
    #         # Teacher models are typically frozen and only updated via EMA
    #         for param in self.teacher[k].parameters():
    #             param.requires_grad_(False)

    #         # Print statistics
    #         total_params = sum(p.numel() for p in self.teacher[k].parameters())
    #         print(f"Teacher {k}: Total params: {total_params:,}, All parameters frozen")

    # def prepare_for_distributed_training(self):
    #     logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")

    #     # Before FSDP wrapping, ensure uniform requires_grad within each BlockChunk
    #     for k, v in self.student.items():
    #         for name, module in v.named_modules():
    #             if isinstance(module, BlockChunk):
    #                 # # Make requires_grad uniform within each BlockChunk
    #                 # # Option 1: Make all parameters in the BlockChunk trainable if any are trainable
    #                 # any_trainable = any(p.requires_grad for p in module.parameters())
    #                 # if any_trainable:
    #                 #     for param in module.parameters():
    #                 #         param.requires_grad = True

    #                 # Option 2: Or alternatively, make all parameters in the BlockChunk non-trainable
    #                 for param in module.parameters():
    #                     param.requires_grad = False

    #     # Now proceed with FSDP wrapping
    #     for k, v in self.student.items():
    #         self.teacher[k].load_state_dict(self.student[k].state_dict())
    #         student_model_cfg = self.cfg.compute_precision.student[k]
    #         self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])
    #         teacher_model_cfg = self.cfg.compute_precision.teacher[k]
    #         self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])


    def freeze_original_dino_weights(self):
        print("Freezing")

        # Freeze the original DINO weights
        for param in self.parameters():
            param.requires_grad = False

        # Unfreeze the pre-encoder, model adapter, and merge blocks
        if hasattr(self, 'pre_encoder'):
            for param in self.pre_encoder.parameters():
                param.requires_grad = True
        if hasattr(self, 'model_adapter'):
            for param in self.model_adapter.parameters():
                param.requires_grad = True
        if hasattr(self, 'merge_blocks'):
            for block in self.merge_blocks:
                for param in block.parameters():
                    param.requires_grad = True
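A small audit helper sketched around the issue the commented-out variants above wrestle with: flat-parameter FSDP expects a uniform requires_grad flag inside each wrapped unit unless use_orig_params=True is set (as get_fsdp_wrapper does). Running something like this before wrapping shows which candidate modules mix frozen and trainable parameters; the function name and usage are assumptions for illustration.

    import torch.nn as nn

    def mixed_requires_grad(root: nn.Module, wrap_cls):
        """Return names of wrap_cls submodules whose parameters mix requires_grad values."""
        offenders = []
        for name, module in root.named_modules():
            if isinstance(module, wrap_cls):
                flags = {p.requires_grad for p in module.parameters()}
                if len(flags) > 1:
                    offenders.append(name)
        return offenders

    # Usage sketch (BlockChunk is the wrap target used in prepare_for_distributed_training):
    # print(mixed_requires_grad(model.student["backbone"], BlockChunk))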
@@ -12,8 +12,7 @@ from functools import partial
from fvcore.common.checkpoint import PeriodicCheckpointer
import torch

from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator

import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger

@@ -118,6 +117,36 @@ def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
        param_group["weight_decay"] = wd * wd_multiplier
        param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier

from torch.distributed.fsdp import StateDictType

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def do_test(cfg, model, iteration):
    # Ensure we are on the main process before saving
    if distributed.is_main_process():
        iterstring = str(iteration)
        eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
        os.makedirs(eval_dir, exist_ok=True)

        # Define full state dict config
        # full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

        # # Set FSDP state_dict type before calling state_dict()
        # with FSDP.state_dict_type(model.teacher, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        #     new_state_dict = model.teacher.state_dict()

        from torch.distributed.fsdp import FullStateDictConfig, StateDictType, state_dict_type

        full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
            checkpointer = FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
            checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)

        # Save teacher checkpoint
        # teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
        # torch.save({"teacher": new_state_dict}, teacher_ckp_path)


def do_test(cfg, model, iteration):
    new_state_dict = model.teacher.state_dict()
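Sketch of the consolidated teacher save that the commented-out lines above point at: gather a FULL_STATE_DICT on rank 0 (CPU-offloaded) and write a single teacher_checkpoint.pth. It assumes model.teacher contains FSDP-wrapped modules inside an initialized process group; the path layout follows the eval_dir logic in do_test.

    import os
    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType

    def save_full_teacher_checkpoint(model, eval_dir, is_main_process):
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model.teacher, StateDictType.FULL_STATE_DICT, cfg):
            state = model.teacher.state_dict()  # populated only on rank 0 with rank0_only=True
        if is_main_process:
            torch.save({"teacher": state}, os.path.join(eval_dir, "teacher_checkpoint.pth"))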
@@ -133,15 +162,208 @@ def do_test(cfg, model, iteration):

from torch.utils.tensorboard import SummaryWriter

# def do_train(cfg, model, resume=False):

#     from dinov2.data import SamplerType, make_data_loader, make_dataset
#     from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
#     writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/with_merged_blocks_small_lr")
#     model.train()
#     inputs_dtype = torch.half
#     fp16_scaler = model.fp16_scaler  # for mixed precision training

#     # setup optimizer
#     print("Cfg optim: ", cfg.optim)
#     optimizer = build_optimizer(cfg, model.get_params_groups())
#     # optimizer = build_optimizer(cfg, [p for p in model.parameters() if p.requires_grad])
#     # trainable_params = [p for p in model.parameters() if p.requires_grad]
#     # optimizer = torch.optim.AdamW(trainable_params, lr=lr)

#     (
#         lr_schedule,
#         wd_schedule,
#         momentum_schedule,
#         teacher_temp_schedule,
#         last_layer_lr_schedule,
#     ) = build_schedulers(cfg)

#     # checkpointer
#     checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)

#     start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

#     OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
#     max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

#     periodic_checkpointer = PeriodicCheckpointer(
#         checkpointer,
#         period=3 * OFFICIAL_EPOCH_LENGTH,
#         max_iter=max_iter,
#         max_to_keep=3,
#     )

#     # setup data preprocessing

#     img_size = cfg.crops.global_crops_size
#     patch_size = cfg.student.patch_size
#     n_tokens = (img_size // patch_size) ** 2
#     mask_generator = MaskingGenerator(
#         input_size=(img_size // patch_size, img_size // patch_size),
#         max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
#     )

#     data_transform = DataAugmentationDINO(
#         cfg.crops.global_crops_scale,
#         cfg.crops.local_crops_scale,
#         cfg.crops.local_crops_number,
#         global_crops_size=cfg.crops.global_crops_size,
#         local_crops_size=cfg.crops.local_crops_size,
#     )

#     collate_fn = partial(
#         collate_data_and_cast,
#         mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
#         mask_probability=cfg.ibot.mask_sample_probability,
#         n_tokens=n_tokens,
#         mask_generator=mask_generator,
#         dtype=inputs_dtype,
#     )

#     # setup data loader

#     dataset = make_dataset(
#         dataset_str=cfg.train.dataset_path,
#         transform=data_transform,
#         target_transform=lambda _: (),
#     )
#     # sampler_type = SamplerType.INFINITE
#     sampler_type = SamplerType.SHARDED_INFINITE
#     data_loader = make_data_loader(
#         dataset=dataset,
#         batch_size=cfg.train.batch_size_per_gpu,
#         num_workers=cfg.train.num_workers,
#         shuffle=True,
#         seed=start_iter,  # TODO: Fix this -- cfg.train.seed
#         sampler_type=sampler_type,
#         sampler_advance=0,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
#         drop_last=True,
#         collate_fn=collate_fn,
#     )

#     # training loop

#     iteration = start_iter

#     logger.info("Starting training from iteration {}".format(start_iter))
#     metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
#     metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
#     header = "Training"

#     for data in metric_logger.log_every(
#         data_loader,
#         10,
#         header,
#         max_iter,
#         start_iter,
#     ):
#         current_batch_size = data["collated_global_crops"].shape[0] / 2
#         if iteration > max_iter:
#             return

#         # apply schedules

#         lr = lr_schedule[iteration]
#         wd = wd_schedule[iteration]
#         mom = momentum_schedule[iteration]
#         teacher_temp = teacher_temp_schedule[iteration]
#         last_layer_lr = last_layer_lr_schedule[iteration]
#         apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)

#         # compute losses

#         optimizer.zero_grad(set_to_none=True)
#         loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)

#         # clip gradients

#         if fp16_scaler is not None:
#             if cfg.optim.clip_grad:
#                 fp16_scaler.unscale_(optimizer)
#                 for v in model.student.values():
#                     v.clip_grad_norm_(cfg.optim.clip_grad)
#             fp16_scaler.step(optimizer)
#             fp16_scaler.update()
#         else:
#             if cfg.optim.clip_grad:
#                 for v in model.student.values():
#                     v.clip_grad_norm_(cfg.optim.clip_grad)
#             optimizer.step()

#         # perform teacher EMA update

#         model.update_teacher(mom)

#         # logging

#         if distributed.get_global_size() > 1:
#             for v in loss_dict.values():
#                 torch.distributed.all_reduce(v)
#         loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}

#         if math.isnan(sum(loss_dict_reduced.values())):
#             logger.info("NaN detected")
#             raise AssertionError
#         losses_reduced = sum(loss for loss in loss_dict_reduced.values())

#         metric_logger.update(lr=lr)
#         metric_logger.update(wd=wd)
#         metric_logger.update(mom=mom)
#         metric_logger.update(last_layer_lr=last_layer_lr)
#         metric_logger.update(current_batch_size=current_batch_size)
#         metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
#         writer.add_scalar('Loss/Total_Loss', losses_reduced, iteration)
#         writer.add_scalar('Learning_Rate', lr, iteration)
#         writer.add_scalar('Weight_Decay', wd, iteration)
#         writer.add_scalar('Momentum', mom, iteration)
#         # checkpointing and testing

#         if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
#             do_test(cfg, model, f"training_{iteration}")
#             torch.cuda.synchronize()
#         periodic_checkpointer.step(iteration)

#         iteration = iteration + 1
#     metric_logger.synchronize_between_processes()
#     writer.close()
#     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def do_train(cfg, model, resume=False):
    writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/standart_custom_conf")
    from dinov2.data import SamplerType, make_data_loader, make_dataset
    from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
    writer = SummaryWriter(log_dir="/home/paperspace/Documents/nika_space/dinov2/dinov2/tensorboard_logs/with_merged_blocks_small_lr")
    model.train()
    inputs_dtype = torch.half
    fp16_scaler = model.fp16_scaler  # for mixed precision training

    # setup optimizer
    # # Freeze the original DINO weights
    # for param in model.parameters():
    #     param.requires_grad = False

    # # Unfreeze the pre-encoder, model adapter, and merge blocks
    # for param in model.pre_encoder.parameters():
    #     param.requires_grad = True
    # for param in model.model_adapter.parameters():
    #     param.requires_grad = True
    # for block in model.merge_blocks:
    #     for param in block.parameters():
    #         param.requires_grad = True

    # setup optimizer
    print("Cfg optim: ", cfg.optim)
    optimizer = build_optimizer(cfg, model.get_params_groups())
    # optimizer = build_optimizer(cfg, [p for p in model.parameters() if p.requires_grad])
    # trainable_params = [p for p in model.parameters() if p.requires_grad]
    # optimizer = torch.optim.AdamW(trainable_params, lr=lr)

    (
        lr_schedule,
        wd_schedule,
@@ -221,6 +443,7 @@ def do_train(cfg, model, resume=False):
    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
    metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
    header = "Training"
    warmup_iterations = cfg.optim.warmup_epochs * OFFICIAL_EPOCH_LENGTH

    for data in metric_logger.log_every(
        data_loader,

@@ -233,6 +456,12 @@ def do_train(cfg, model, resume=False):
        if iteration > max_iter:
            return

        # Unfreeze original DINO weights after warmup
        if iteration == warmup_iterations:
            for param in model.parameters():
                param.requires_grad = True


        # apply schedules

        lr = lr_schedule[iteration]

@@ -263,7 +492,6 @@ def do_train(cfg, model, resume=False):
            optimizer.step()

        # perform teacher EMA update

        model.update_teacher(mom)

        # logging
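Toy sketch of the unfreeze-after-warmup schedule added above: only adapter modules train during warmup, everything trains afterwards. Numbers and module names are illustrative; in do_train the threshold is cfg.optim.warmup_epochs * OFFICIAL_EPOCH_LENGTH.

    import torch.nn as nn

    model = nn.ModuleDict({"backbone": nn.Linear(8, 8), "model_adapter": nn.Linear(8, 8)})
    warmup_iterations = 3
    for name, p in model.named_parameters():           # warmup phase: adapters only
        p.requires_grad = "model_adapter" in name

    for iteration in range(6):
        if iteration == warmup_iterations:              # same trigger as in do_train
            for p in model.parameters():
                p.requires_grad = True
        n_trainable = sum(p.requires_grad for p in model.parameters())
        print(iteration, n_trainable)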
@@ -315,16 +543,17 @@ def parse_merge_block_indexes(config_value: str) -> List[int]:
        return []
    return list(map(int, re.split(r'\s*,\s*', config_value)))


def main(args):
    cfg = setup(args)
    print("cfg.student.merge_block_indexes ", cfg.student.merge_block_indexes)
    cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes)
    print("cfg.student.merge_block_indexes:", cfg.student.merge_block_indexes)

    # cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes)
    # print("INDEXES", cfg.student.merge_block_indexes)
    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()

    # logger.info("Model:\n{}".format(model))

    # model.freeze_original_dino_weights()

    if args.eval_only:
        iteration = (
            FSDPCheckpointer(model, save_dir=cfg.train.output_dir)

@@ -333,7 +562,7 @@ def main(args):
            + 1
        )
        return do_test(cfg, model, f"manual_{iteration}")

    do_train(cfg, model, resume=not args.no_resume)
@@ -18,17 +18,26 @@ logger = logging.getLogger("dinov2")


def load_pretrained_weights(model, pretrained_weights, checkpoint_key):

    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        state_dict = state_dict['model']
    if checkpoint_key is not None:
        print("INSIDE IF")
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
        state_dict = {key: value for key, value in state_dict.items() if key.startswith(checkpoint_key) }
        # state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

    state_dict = {k.replace("teacher.backbone.", ""): v for k, v in state_dict.items()}
    # print("Model Keys:")
    # for key in state_dict.keys():
    #     print(key)
    # print("State dict: ", state_dict.keys())
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
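Tiny illustration of the key-rewriting chain in load_pretrained_weights above, with made-up keys. Note that the replacements run in the order written: for a teacher.backbone.* key the "backbone." pass fires first, so the "teacher.backbone." pass no longer matches and a "teacher." prefix survives.

    ckpt_keys = {
        "module.cls_token": 1,
        "backbone.pos_embed": 2,
        "teacher.backbone.patch_embed.proj.weight": 3,
    }
    state_dict = {k.replace("module.", ""): v for k, v in ckpt_keys.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("teacher.backbone.", ""): v for k, v in state_dict.items()}
    print(sorted(state_dict))
    # ['cls_token', 'pos_embed', 'teacher.patch_embed.proj.weight']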