# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
import logging
import copy

import torch
from torch import nn

from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss
from dinov2.models import build_model
from dinov2.layers import DINOHead
from dinov2.utils.utils import has_batchnorms
from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups
from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model

from dinov2.models.vision_transformer import BlockChunk

try:
    from xformers.ops import fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False
assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"


logger = logging.getLogger("dinov2")
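

# DistillMetaArch is a distillation variant of the SSL meta-architecture: the teacher is a
# frozen, pretrained backbone (loaded from cfg.teacher.pretrained_weights), the student is
# trained against it with the DINO / iBOT / KoLeo objectives, and a "shadow" copy of the
# student is maintained as an exponential moving average (see update_student_shadow).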
class DistillMetaArch(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None

        student_model_dict = dict()
        teacher_model_dict = dict()

        student_backbone, student_embed_dim = build_model(cfg.student, only_teacher=True, img_size=cfg.crops.global_crops_size)  # pyright: ignore
        teacher_backbone, teacher_embed_dim = build_model(cfg.teacher, only_teacher=True, img_size=cfg.crops.global_crops_size)  # pyright: ignore
        student_model_dict["backbone"] = student_backbone
        teacher_model_dict["backbone"] = teacher_backbone

        # Student and teacher embedding dimensions can differ, so the DINO and iBOT heads are built separately for each model
        logger.info(f"OPTIONS -- architecture : student_embed_dim: {student_embed_dim}")
        logger.info(f"OPTIONS -- architecture : teacher_embed_dim: {teacher_embed_dim}")

        self.student_embed_dim = student_embed_dim
        self.teacher_embed_dim = teacher_embed_dim
        self.dino_out_dim = cfg.dino.head_n_prototypes

        self.do_dino = cfg.dino.loss_weight > 0
        self.do_koleo = cfg.dino.koleo_loss_weight > 0
        self.do_ibot = cfg.ibot.loss_weight > 0
        self.ibot_separate_head = cfg.ibot.separate_head

        logger.info("OPTIONS -- DINO")
        if self.do_dino:
            logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}")
            logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}")
            logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}")
            logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}")
            self.dino_loss_weight = cfg.dino.loss_weight
            dino_head = partial(
                DINOHead,
                out_dim=cfg.dino.head_n_prototypes,
                hidden_dim=cfg.dino.head_hidden_dim,
                bottleneck_dim=cfg.dino.head_bottleneck_dim,
                nlayers=cfg.dino.head_nlayers,
            )

            self.dino_loss = DINOLoss(self.dino_out_dim)
            if self.do_koleo:
                logger.info("OPTIONS -- DINO -- applying KOLEO regularization")
                self.koleo_loss = KoLeoLoss()

        else:
            logger.info("OPTIONS -- DINO -- not using DINO")

        if self.do_dino or self.do_ibot:
            student_model_dict["dino_head"] = dino_head(in_dim=student_embed_dim)  # pyright: ignore
            teacher_model_dict["dino_head"] = dino_head(in_dim=teacher_embed_dim)  # pyright: ignore

        logger.info("OPTIONS -- IBOT")
        logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
        logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}")
        logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}")
        if self.do_ibot:
            self.ibot_loss_weight = cfg.ibot.loss_weight
            assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot"
            assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot"
            self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes
            self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim)
            if self.ibot_separate_head:
                logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
                logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}")
                logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}")
                logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}")
                ibot_head = partial(
                    DINOHead,
                    out_dim=cfg.ibot.head_n_prototypes,
                    hidden_dim=cfg.ibot.head_hidden_dim,
                    bottleneck_dim=cfg.ibot.head_bottleneck_dim,
                    nlayers=cfg.ibot.head_nlayers,
                )
                student_model_dict["ibot_head"] = ibot_head(in_dim=student_embed_dim)
                teacher_model_dict["ibot_head"] = ibot_head(in_dim=teacher_embed_dim)
            else:
                logger.info("OPTIONS -- IBOT -- head shared with DINO")

        # self.need_to_synchronize_fsdp_streams = True

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)
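        # keep a frozen copy of the student; it is updated as an EMA in update_student_shadow()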
        self.student_shadow = copy.deepcopy(self.student)
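
        # the teacher is initialized from a pretrained checkpoint and never trained here;
        # strict=False tolerates keys that differ between the checkpoint and this module (e.g. extra heads)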
        assert cfg.teacher.pretrained_weights is not None, "Must contain pretrained weights for distillation."
        chkpt = torch.load(cfg.teacher.pretrained_weights)
        logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.teacher.pretrained_weights}")
        self.teacher.load_state_dict(chkpt["model"], strict=False)

        # there is no backpropagation through the teacher, so no need for gradients
        for p in self.teacher.parameters():
            p.requires_grad = False

        # there is no backpropagation through the shadow copy either
        for p in self.student_shadow.parameters():
            p.requires_grad = False

        logger.info(f"Student is built: it is using {cfg.student.arch}")
        logger.info(f"Teacher is built: it is using {cfg.teacher.arch}")

    def forward(self, inputs):
        raise NotImplementedError
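
    # When mixed precision is enabled, the loss is scaled before backward so that small
    # gradients do not underflow in fp16; the grad scaler unscales them again before the
    # optimizer step in the training loop.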
    def backprop_loss(self, loss):
        if self.fp16_scaler is not None:
            self.fp16_scaler.scale(loss).backward()  # pyright: ignore
        else:
            loss.backward()
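
    # One training step: compute teacher targets without gradients, forward the student on
    # global + local crops, accumulate the weighted DINO / KoLeo / iBOT losses, run backward,
    # and return the individual loss values for logging (the optimizer step happens outside).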
    def forward_backward(self, images, teacher_temp):
        n_global_crops = 2
        assert n_global_crops == 2
        n_local_crops = self.cfg.crops.local_crops_number

        global_crops = images["collated_global_crops"].cuda(non_blocking=True)
        local_crops = images["collated_local_crops"].cuda(non_blocking=True)

        masks = images["collated_masks"].cuda(non_blocking=True)
        mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True)
        n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True)
        n_masked_patches = mask_indices_list.shape[0]
        upperbound = images["upperbound"]
        masks_weight = images["masks_weight"].cuda(non_blocking=True)
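
        # number of loss terms used for normalization: every local crop is compared against
        # every global teacher crop, and the two global crops are compared with each other
        # (a crop is never compared with itself)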
        n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)
        n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops

        do_dino = self.do_dino
        do_ibot = self.do_ibot

        # loss scales
        ibot_loss_scale = 1.0 / n_global_crops

        # teacher output
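        # The teacher only sees the two global crops. Its cls tokens (and, for iBOT, the
        # masked patch tokens) are pushed through the head(s) and turned into sharpened,
        # centered target distributions below.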
        @torch.no_grad()
        def get_teacher_output():
            x, n_global_crops_teacher = global_crops, n_global_crops
            teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True)  # pyright: ignore
            teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
            teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher)
            # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss
            teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0]))
            ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
            _dim = ibot_teacher_patch_tokens.shape[-1]
            n_cls_tokens = teacher_cls_tokens.shape[0]
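
            # "upperbound" is the precomputed maximum number of masked patches in the batch;
            # allocating a buffer of that size keeps the head input shape independent of the
            # actual mask count in this iteration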
            if do_ibot and not self.ibot_separate_head:
                buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim)
                buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens)
                torch.index_select(
                    ibot_teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches],
                )
                tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher)
                teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens]
                masked_teacher_patch_tokens_after_head = tokens_after_head[
                    n_cls_tokens : n_cls_tokens + n_masked_patches
                ]
            elif do_ibot and self.ibot_separate_head:
                buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim)
                torch.index_select(
                    ibot_teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[:n_masked_patches],
                )
                teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
                masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[
                    :n_masked_patches
                ]
            else:
                teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
                masked_teacher_ibot_softmaxed_centered = None
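
            # convert teacher logits into target distributions, either with softmax over a
            # running center ("centering") or with Sinkhorn-Knopp normalization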
            if self.cfg.train.centering == "centering":
                teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher(
                    teacher_cls_tokens_after_head, teacher_temp=teacher_temp
                ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])
                self.dino_loss.update_center(teacher_cls_tokens_after_head)
                if do_ibot:
                    masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0)
                    masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher(
                        masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp
                    )
                    masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0)
                    self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches])

            elif self.cfg.train.centering == "sinkhorn_knopp":
                teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher(
                    teacher_cls_tokens_after_head, teacher_temp=teacher_temp
                ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])

                if do_ibot:
                    masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher(
                        masked_teacher_patch_tokens_after_head,
                        teacher_temp=teacher_temp,
                        n_masked_patches_tensor=n_masked_patches_tensor,
                    )

            else:
                raise NotImplementedError

            return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered

        teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()
        reshard_fsdp_model(self.teacher)

        loss_dict = {}

        loss_accumulator = 0  # for backprop
        student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone(
            [global_crops, local_crops], masks=[masks, None], is_training=True
        )

        inputs_for_student_head_list = []

        # 1a: local crops cls tokens
        student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"]
        inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0))

        # 1b: global crops cls tokens
        student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"]
        inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0))

        # 1c: global crops patch tokens
        if do_ibot:
            _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1]
            ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"]
            buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim)
            buffer_tensor_patch_tokens[:n_masked_patches].copy_(
                torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list)
            )
            if not self.ibot_separate_head:
                inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0))
            else:
                student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[
                    :n_masked_patches
                ]

        # 2: run
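        # the collected inputs are concatenated with xformers' BlockDiagonalMask so a single
        # dino_head forward covers every crop (and, for the shared-head case, the masked patch
        # token buffer), then split back into per-input outputs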
        _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
        outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs))

        # 3a: local crops cls tokens
        student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)

        # 3b: global crops cls tokens
        student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)

        # 3c: global crops patch tokens
        if do_ibot and not self.ibot_separate_head:
            student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches]

        if n_local_crops > 0:
            dino_local_crops_loss = self.dino_loss(
                student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops),
                teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list,
            ) / (n_global_crops_loss_terms + n_local_crops_loss_terms)

            # store for display
            loss_dict["dino_local_crops_loss"] = dino_local_crops_loss

            # accumulate loss
            loss_accumulator += self.dino_loss_weight * dino_local_crops_loss

        # process global crops
        loss_scales = 2  # this is here since we process global crops together

        if do_dino:
            # compute loss
            dino_global_crops_loss = (
                self.dino_loss(
                    student_output_list=[student_global_cls_tokens_after_head],
                    teacher_out_softmaxed_centered_list=[
                        teacher_dino_softmaxed_centered_list.flatten(0, 1)
                    ],  # these were chunked and stacked in reverse so A is matched to B
                )
                * loss_scales
                / (n_global_crops_loss_terms + n_local_crops_loss_terms)
            )

            loss_dict["dino_global_crops_loss"] = dino_global_crops_loss

            # accumulate loss
            loss_accumulator += self.dino_loss_weight * dino_global_crops_loss

            student_cls_tokens = student_global_cls_tokens

            if self.do_koleo:
                koleo_loss = self.cfg.dino.koleo_loss_weight * sum(
                    self.koleo_loss(p) for p in student_cls_tokens.chunk(2)
                )  # we don't apply koleo loss between cls tokens of a same image
                loss_accumulator += koleo_loss
                loss_dict["koleo_loss"] = (
                    koleo_loss / loss_scales
                )  # this is to display the same losses as before but we can remove eventually

        if do_ibot:
            # compute loss
            ibot_patch_loss = (
                self.ibot_patch_loss.forward_masked(
                    student_global_masked_patch_tokens_after_head,
                    masked_teacher_ibot_softmaxed_centered,
                    student_masks_flat=masks,
                    n_masked_patches=n_masked_patches,
                    masks_weight=masks_weight,
                )
                * loss_scales
                * ibot_loss_scale
            )

            # store for display
            loss_dict["ibot_loss"] = ibot_patch_loss / 2

            # accumulate loss
            loss_accumulator += self.ibot_loss_weight * ibot_patch_loss

        self.backprop_loss(loss_accumulator)

        # self.fsdp_synchronize_streams()

        return loss_dict

    # def fsdp_synchronize_streams(self):
    #     if self.need_to_synchronize_fsdp_streams:
    #         torch.cuda.synchronize()
    #         self.student.dino_head._streams = (
    #             self.teacher.dino_head._streams
    #         ) = self.student.backbone._streams = self.teacher.backbone._streams
    #         self.need_to_synchronize_fsdp_streams = False

    # def update_teacher(self, m):
    #     student_param_list = []
    #     teacher_param_list = []
    #     with torch.no_grad():
    #         for k in self.student.keys():
    #             for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])):
    #                 student_param_list += ms.params
    #                 teacher_param_list += mt.params
    #         torch._foreach_mul_(teacher_param_list, m)
    #         torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m)
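
    # EMA update of the shadow student: shadow = m * shadow + (1 - m) * student, applied to
    # the flat parameters of each FSDP-wrapped submodule with fused _foreach_ ops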
    def update_student_shadow(self, m):
        student_param_list = []
        student_shadow_param_list = []
        with torch.no_grad():
            for k in self.student.keys():
                for ms, mss in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.student_shadow[k])):
                    student_param_list += ms.params
                    student_shadow_param_list += mss.params
            torch._foreach_mul_(student_shadow_param_list, m)
            torch._foreach_add_(student_shadow_param_list, student_param_list, alpha=1 - m)

    def train(self):
        super().train()
        self.teacher.eval()
        self.student_shadow.eval()
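
    # build optimizer param groups with layer-wise lr decay, then fuse groups that share the
    # same hyper-parameters so the optimizer can use multi-tensor (foreach) kernels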
    def get_maybe_fused_params_for_submodel(self, m):
        params_groups = get_params_groups_with_decay(
            model=m,
            lr_decay_rate=self.cfg.optim.layerwise_decay,
            patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult,
        )
        fused_params_groups = fuse_params_groups(params_groups)
        logger.info("fusing param groups")

        for g in fused_params_groups:
            g["foreach"] = True
        return fused_params_groups

    def get_params_groups(self):
        all_params_groups = []
        for m in self.student.values():
            all_params_groups += self.get_maybe_fused_params_for_submodel(m)
        return all_params_groups
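
    # sync the shadow copy with the freshly built student weights, then wrap every student and
    # teacher submodule with FSDP according to the per-module compute-precision config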
    def prepare_for_distributed_training(self):
        logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
        if has_batchnorms(self.student):
            raise NotImplementedError

        # below will synchronize all student subnetworks across gpus:
        for k, v in self.student.items():
            self.student_shadow[k].load_state_dict(self.student[k].state_dict())
            # self.teacher[k].load_state_dict(self.student[k].state_dict())
            student_model_cfg = self.cfg.compute_precision.student[k]
            self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])

        for k, v in self.teacher.items():
            teacher_model_cfg = self.cfg.compute_precision.teacher[k]
            self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
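

# Usage sketch (illustrative only, not part of this module): roughly how a training loop
# might drive DistillMetaArch. The cfg, data_loader, teacher_temp and momentum objects and
# the AdamW setup below are assumptions borrowed from the surrounding dinov2 training code,
# not defined in this file.
#
#   model = DistillMetaArch(cfg).cuda()
#   model.prepare_for_distributed_training()
#   optimizer = torch.optim.AdamW(model.get_params_groups(), betas=(0.9, 0.999))
#   for data in data_loader:
#       optimizer.zero_grad(set_to_none=True)
#       loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
#       if model.fp16_scaler is not None:
#           model.fp16_scaler.step(optimizer)
#           model.fp16_scaler.update()
#       else:
#           optimizer.step()
#       model.update_student_shadow(momentum)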