work for 1 gpu vit model

parent 2aae191d0b
commit d683211ae3
@@ -1,126 +0,0 @@
MODEL:
  WEIGHTS: ''
compute_precision:
  grad_scaler: true
  teacher:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
  student:
    backbone:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp16
        buffer_dtype: fp32
    dino_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
    ibot_head:
      sharding_strategy: SHARD_GRAD_OP
      mixed_precision:
        param_dtype: fp16
        reduce_dtype: fp32
        buffer_dtype: fp32
dino:
  loss_weight: 1.0
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_nlayers: 3
  head_hidden_dim: 2048
  koleo_loss_weight: 0.1
ibot:
  loss_weight: 1.0
  mask_sample_probability: 0.0
  mask_ratio_min_max:
  - 0.1
  - 0.5
  separate_head: false
  head_n_prototypes: 65536
  head_bottleneck_dim: 256
  head_nlayers: 3
  head_hidden_dim: 2048
train:
  batch_size_per_gpu: 64
  dataset_path: ImageNet:split=TRAIN
  output_dir: .
  saveckp_freq: 20
  seed: 0
  num_workers: 10
  OFFICIAL_EPOCH_LENGTH: 1250
  cache_dataset: true
  centering: "centering" # or "sinkhorn_knopp"
student:
  arch: vit_large
  patch_size: 16
  drop_path_rate: 0.0
  layerscale: 1.0e-05
  drop_path_uniform: true
  pretrained_weights: ''
  ffn_layer: "mlp"
  block_chunks: 0
  qkv_bias: true
  proj_bias: true
  ffn_bias: true
teacher:
  arch: vit_base
  patch_size: 16
  drop_path_rate: 0.0
  layerscale: 1.0e-05
  drop_path_uniform: true
  pretrained_weights: ''
  ffn_layer: "mlp"
  block_chunks: 0
  qkv_bias: true
  proj_bias: true
  ffn_bias: true
  momentum_teacher: 0.992
  final_momentum_teacher: 1
  warmup_teacher_temp: 0.04
  teacher_temp: 0.07
  warmup_teacher_temp_epochs: 30
optim:
  epochs: 100
  weight_decay: 0.04
  weight_decay_end: 0.4
  base_lr: 0.004 # learning rate for a batch size of 1024
  lr: 0. # will be set after applying scaling rule
  warmup_epochs: 10
  min_lr: 1.0e-06
  clip_grad: 3.0
  freeze_last_layer_epochs: 1
  scaling_rule: sqrt_wrt_1024
  patch_embed_lr_mult: 0.2
  layerwise_decay: 0.9
  adamw_beta1: 0.9
  adamw_beta2: 0.999
crops:
  global_crops_scale:
  - 0.32
  - 1.0
  local_crops_number: 8
  local_crops_scale:
  - 0.05
  - 0.32
  global_crops_size: 224
  local_crops_size: 96
evaluation:
  eval_period_iterations: 12500
@@ -0,0 +1,7 @@
# this corresponds to the default config
train:
  dataset_path: ImageNet:split=TRAIN
  batch_size_per_gpu: 64
student:
  arch: vit_large # vit_large, nextvit
  block_chunks: 4
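Editor's sketch, not part of the commit: how the sqrt_wrt_1024 scaling rule in the configs above is usually described as resolving lr for a single-GPU run. The exact implementation lives in dinov2's apply_scaling_rules_to_cfg, so treat the formula here as an assumption to verify against that helper.

import math

base_lr = 0.004          # from the config: defined for a global batch size of 1024
batch_size_per_gpu = 64  # from the config
num_gpus = 1             # this commit targets a single-GPU run

# assumed rule: lr = base_lr * sqrt(global_batch / 1024)
global_batch = batch_size_per_gpu * num_gpus
lr = base_lr * math.sqrt(global_batch / 1024)
print(f"effective lr for global batch {global_batch}: {lr:.4f}")  # 0.0010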
@@ -12,8 +12,8 @@ from typing import Dict, List
import torch
import torch.distributed as dist

_LOCAL_RANK = -1
_LOCAL_WORLD_SIZE = -1
_LOCAL_RANK = 0
_LOCAL_WORLD_SIZE = 1


def is_enabled() -> bool:
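Editor's sketch, not part of the commit: hard-coding _LOCAL_RANK = 0 and _LOCAL_WORLD_SIZE = 1 above makes the dinov2.distributed helpers behave as a single process, which is the same answer the standard torch.distributed guard gives when no process group has been initialized.

import torch.distributed as dist

def effective_world_size() -> int:
    # single-process fallback when torch.distributed was never initialized
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

print(effective_world_size())  # 1 on a plain single-GPU run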
@@ -6,6 +6,7 @@
import logging

from . import vision_transformer as vits
from .nextvit import NextVitSmall


logger = logging.getLogger("dinov2")
@@ -27,14 +28,20 @@ def build_model(args, only_teacher=False, img_size=224):
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if args.arch in vits.__dict__:
            teacher = vits.__dict__[args.arch](**vit_kwargs)
        else:
            teacher = NextVitSmall()
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        if args.arch in vits.__dict__:
            student = vits.__dict__[args.arch](
                **vit_kwargs,
                drop_path_rate=args.drop_path_rate,
                drop_path_uniform=args.drop_path_uniform,
            )
        else:
            student = NextVitSmall()
        embed_dim = student.embed_dim
    return student, teacher, embed_dim
@@ -0,0 +1,50 @@
import torch
from torch.nn import functional

import sys
sys.path.append('/home/li.yu/code/JupiterCVML/europa/base/src/europa')
from dl.network.nextvit_brt import _get_nextvit


class NextVitSmall(torch.nn.Module):
    """BRT Segmentation model with definition to make it a custom model supported."""

    def __init__(self, num_classes=197*1024) -> None:
        super().__init__()

        # define backbone
        self.backbone = _get_nextvit(
            model_size="small",
            frozen_stages=-1,
            norm_eval=False,
            with_extra_norm=True,
            norm_cfg=dict(type="SyncBN", requires_grad=True),
            in_channels=3,
        )

        # self.proj_head = torch.nn.Sequential(
        #     torch.nn.Linear(1024, num_classes),
        # )
        assert num_classes == 197 * 1024
        self.num_register_tokens = 1
        self.embed_dim = 1024
        self.proj_head = torch.nn.Linear(1024, num_classes)

    def forward_backbone(self, x, masks=None):
        y = self.backbone(x)
        y = functional.adaptive_avg_pool2d(y[-1], (1, 1))
        y = torch.flatten(y, 1)
        y = self.proj_head(y)

        n = y.shape[0]
        y_reshaped = y.reshape(n, 197, 1024)
        return {
            "x_norm_clstoken": y_reshaped[:, 0],  # teacher 128x1024
            "x_norm_regtokens": y_reshaped[:, 1 : self.num_register_tokens + 1],  # teacher 128x0x1024
            "x_norm_patchtokens": y_reshaped[:, self.num_register_tokens + 1 :],  # teacher 128x196x1024
            "x_prenorm": None,
            "masks": masks,
        }

    def forward(self, x, *args, masks=None, **kwargs):
        if isinstance(x, list):
            return [self.forward_backbone(_x, _masks) for _x, _masks in zip(x, masks)]
        return self.forward_backbone(x, masks)
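Editor's sketch, not part of the commit: the shape bookkeeping NextVitSmall relies on, demonstrated with a random tensor instead of the proprietary NextViT backbone. Note that with num_register_tokens = 1 the patch slice holds 195 tokens, not the 196 mentioned in the inline comments; whether that is intended is not clear from the diff.

import torch

n, embed_dim, num_register_tokens = 2, 1024, 1
y = torch.randn(n, 197 * embed_dim)                    # stand-in for proj_head(backbone(x))
tokens = y.reshape(n, 197, embed_dim)

cls_token = tokens[:, 0]                               # (2, 1024)
reg_tokens = tokens[:, 1 : num_register_tokens + 1]    # (2, 1, 1024)
patch_tokens = tokens[:, num_register_tokens + 1 :]    # (2, 195, 1024)
print(cls_token.shape, reg_tokens.shape, patch_tokens.shape)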
@@ -212,14 +212,14 @@ class DinoVisionTransformer(nn.Module):

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        x = self.patch_embed(x)  # teacher 128x3x224x224 -> 128x196x1024
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)  # teacher 128x197x1024
        x = x + self.interpolate_pos_encoding(x, w, h)  # teacher 128x197x1024

        if self.register_tokens is not None:
        if self.register_tokens is not None:  # self.register_tokens is None
            x = torch.cat(
                (
                    x[:, :1],
@@ -232,9 +232,9 @@ class DinoVisionTransformer(nn.Module):
        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]  # student 128x3x224x224 -> 128x197x1024, 512x3x96x96 -> 512x37x1024
        for blk in self.blocks:
            x = blk(x)
            x = blk(x)  # student 128x197x1024, 512x37x1024

        all_x = x
        output = []
@@ -242,30 +242,30 @@ class DinoVisionTransformer(nn.Module):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                    "x_norm_clstoken": x_norm[:, 0],  # student 128x1024, 512x1024
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],  # student 128x0x1024, 512x0x1024
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],  # student 128x196x1024, 512x36x1024
                    "x_prenorm": x,  # student 128x197x1024, 512x37x1024
                    "masks": masks,  # student 128x196, None
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
        if isinstance(x, list):  # student x [128x3x224x224, 512x3x96x96], masks [128x196, None]
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)
        x = self.prepare_tokens_with_masks(x, masks)  # teacher 128x3x224x224 -> 128x197x1024

        for blk in self.blocks:
            x = blk(x)
            x = blk(x)  # teacher 128x197x1024

        x_norm = self.norm(x)
        x_norm = self.norm(x)  # teacher 128x197x1024
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "x_norm_clstoken": x_norm[:, 0],  # teacher 128x1024
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],  # teacher 128x0x1024
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],  # teacher 128x196x1024
            "x_prenorm": x,  # teacher 128x197x1024
            "masks": masks,
        }
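Editor's sketch, not part of the commit: where the 197- and 37-token shapes in the annotations above come from, for a patch size of 16 and no register tokens.

def num_tokens(img_size: int, patch_size: int = 16) -> int:
    # patch tokens plus one class token
    return (img_size // patch_size) ** 2 + 1

print(num_tokens(224))  # 196 + 1 = 197 (global crops)
print(num_tokens(96))   # 36 + 1 = 37 (local crops)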
@@ -9,7 +9,7 @@ import os
from pathlib import Path
from typing import List, Optional

import submitit
# import submitit

from dinov2.utils.cluster import (
    get_slurm_executor_parameters,
@@ -61,16 +61,10 @@ def get_args_parser(
        help="Partition where to submit",
    )
    parser.add_argument(
        "--mem-per-gpu",
        default="60G",
        type=str,
        help="Memory per GPU",
        "--use-volta32",
        action="store_true",
        help="Request V100-32GB GPUs",
    )
    # parser.add_argument(
    #     "--use-volta32",
    #     action="store_true",
    #     help="Request V100-32GB GPUs",
    # )
    parser.add_argument(
        "--comment",
        default="",
@@ -103,8 +97,8 @@ def submit_jobs(task_class, args, name: str):
    executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)

    kwargs = {}
    # if args.use_volta32:
    #     kwargs["slurm_constraint"] = "volta32gb"
    if args.use_volta32:
        kwargs["slurm_constraint"] = "volta32gb"
    if args.comment:
        kwargs["slurm_comment"] = args.comment
    if args.exclude:
@@ -116,11 +110,10 @@ def submit_jobs(task_class, args, name: str):
        timeout_min=args.timeout,  # max is 60 * 72
        slurm_signal_delay_s=120,
        slurm_partition=args.partition,
        mem_per_gpu=args.mem_per_gpu,
        **kwargs,
    )
    executor.update_parameters(name=name, **executor_params)
    print(args, executor_params)

    task = task_class(args)
    job = executor.submit(task)
@@ -9,7 +9,8 @@ import sys

from dinov2.logging import setup_logging
from dinov2.train import get_args_parser as get_train_args_parser
from dinov2.run.submit import get_args_parser, submit_jobs
# from dinov2.run.submit import get_args_parser, submit_jobs
from dinov2.run.submit import get_args_parser


logger = logging.getLogger("dinov2")
@@ -22,7 +23,7 @@ class Trainer(object):
    def __call__(self):
        from dinov2.train import main as train_main

        self._setup_args()
        # self._setup_args()
        train_main(self.args)

    def checkpoint(self):
@@ -51,7 +52,13 @@ def main():
    setup_logging()

    assert os.path.exists(args.config_file), "Configuration file does not exist!"
    submit_jobs(Trainer, args, name="dinov2:train")
    print(args)
    # submit_jobs(Trainer, args, name="dinov2:train")

    from dinov2.train import main as train_main

    train_main(args)

    return 0
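Editor's sketch, not part of the commit: with submit_jobs commented out, main() above boils down to a local launcher roughly like the one below. The parser wiring follows the upstream dinov2 launcher (get_args_parser accepting parents=...); treat that signature and the call sequence as assumptions to check against the actual file.

from dinov2.logging import setup_logging
from dinov2.run.submit import get_args_parser
from dinov2.train import get_args_parser as get_train_args_parser
from dinov2.train import main as train_main

def run_locally(argv=None):
    # assumed wiring: the submit parser wraps the train parser as a parent
    parser = get_args_parser(parents=[get_train_args_parser(add_help=False)])
    args = parser.parse_args(argv)
    setup_logging()
    return train_main(args)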
@@ -1,431 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
import logging
import copy

import torch
from torch import nn

from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss
from dinov2.models import build_model
from dinov2.layers import DINOHead
from dinov2.utils.utils import has_batchnorms
from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups
from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model

from dinov2.models.vision_transformer import BlockChunk

try:
    from xformers.ops import fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False
assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"


logger = logging.getLogger("dinov2")


class DistillMetaArch(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None

        student_model_dict = dict()
        teacher_model_dict = dict()

        student_backbone, student_embed_dim = build_model(cfg.student, only_teacher=True, img_size=cfg.crops.global_crops_size)  # pyright: ignore
        teacher_backbone, teacher_embed_dim = build_model(cfg.teacher, only_teacher=True, img_size=cfg.crops.global_crops_size)  # pyright: ignore
        student_model_dict["backbone"] = student_backbone
        teacher_model_dict["backbone"] = teacher_backbone

        # Student and teacher embedding dimensions can be different and therefore DINO head and IBOT heads should be different
        logger.info(f"OPTIONS -- architecture : student_embed_dim: {student_embed_dim}")
        logger.info(f"OPTIONS -- architecture : teacher_embed_dim: {teacher_embed_dim}")

        self.student_embed_dim = student_embed_dim
        self.teacher_embed_dim = teacher_embed_dim
        self.dino_out_dim = cfg.dino.head_n_prototypes

        self.do_dino = cfg.dino.loss_weight > 0
        self.do_koleo = cfg.dino.koleo_loss_weight > 0
        self.do_ibot = cfg.ibot.loss_weight > 0
        self.ibot_separate_head = cfg.ibot.separate_head

        logger.info("OPTIONS -- DINO")
        if self.do_dino:
            logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}")
            logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}")
            logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}")
            logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}")
            self.dino_loss_weight = cfg.dino.loss_weight
            dino_head = partial(
                DINOHead,
                out_dim=cfg.dino.head_n_prototypes,
                hidden_dim=cfg.dino.head_hidden_dim,
                bottleneck_dim=cfg.dino.head_bottleneck_dim,
                nlayers=cfg.dino.head_nlayers,
            )

            self.dino_loss = DINOLoss(self.dino_out_dim)
            if self.do_koleo:
                logger.info("OPTIONS -- DINO -- applying KOLEO regularization")
                self.koleo_loss = KoLeoLoss()

        else:
            logger.info("OPTIONS -- DINO -- not using DINO")

        if self.do_dino or self.do_ibot:
            student_model_dict["dino_head"] = dino_head(in_dim=student_embed_dim)  # pyright: ignore
            teacher_model_dict["dino_head"] = dino_head(in_dim=teacher_embed_dim)  # pyright: ignore

        logger.info("OPTIONS -- IBOT")
        logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
        logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}")
        logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}")
        if self.do_ibot:
            self.ibot_loss_weight = cfg.ibot.loss_weight
            assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot"
            assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot"
            self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes
            self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim)
            if self.ibot_separate_head:
                logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
                logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}")
                logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}")
                logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}")
                ibot_head = partial(
                    DINOHead,
                    out_dim=cfg.ibot.head_n_prototypes,
                    hidden_dim=cfg.ibot.head_hidden_dim,
                    bottleneck_dim=cfg.ibot.head_bottleneck_dim,
                    nlayers=cfg.ibot.head_nlayers,
                )
                student_model_dict["ibot_head"] = ibot_head(in_dim=student_embed_dim)
                teacher_model_dict["ibot_head"] = ibot_head(in_dim=teacher_embed_dim)
            else:
                logger.info("OPTIONS -- IBOT -- head shared with DINO")

        # self.need_to_synchronize_fsdp_streams = True

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)
        self.student_shadow = copy.deepcopy(self.student)

        assert cfg.teacher.pretrained_weights is not None, "Must contain pretrained weights for distillation."
        chkpt = torch.load(cfg.teacher.pretrained_weights)
        logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.teacher.pretrained_weights}")
        self.teacher.load_state_dict(chkpt["model"], strict=False)

        # there is no backpropagation through the teacher, so no need for gradients
        for p in self.teacher.parameters():
            p.requires_grad = False

        # there is no backpropagation through the shadow copy either
        for p in self.student_shadow.parameters():
            p.requires_grad = False

        logger.info(f"Student is built: it is using {cfg.student.arch}")
        logger.info(f"Teacher is built: it is using {cfg.teacher.arch}")

    def forward(self, inputs):
        raise NotImplementedError

    def backprop_loss(self, loss):
        if self.fp16_scaler is not None:
            self.fp16_scaler.scale(loss).backward()  # pyright: ignore
        else:
            loss.backward()

    def forward_backward(self, images, teacher_temp):
        n_global_crops = 2
        assert n_global_crops == 2
        n_local_crops = self.cfg.crops.local_crops_number

        global_crops = images["collated_global_crops"].cuda(non_blocking=True)
        local_crops = images["collated_local_crops"].cuda(non_blocking=True)

        masks = images["collated_masks"].cuda(non_blocking=True)
        mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True)
        n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True)
        n_masked_patches = mask_indices_list.shape[0]
        upperbound = images["upperbound"]
        masks_weight = images["masks_weight"].cuda(non_blocking=True)

        n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)
        n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops

        do_dino = self.do_dino
        do_ibot = self.do_ibot

        # loss scales
        ibot_loss_scale = 1.0 / n_global_crops

        # teacher output
        @torch.no_grad()
        def get_teacher_output():
            x, n_global_crops_teacher = global_crops, n_global_crops
            teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True)  # pyright: ignore
            teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
            teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher)
            # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss
            teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0]))
            ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
            _dim = ibot_teacher_patch_tokens.shape[-1]
            n_cls_tokens = teacher_cls_tokens.shape[0]

            if do_ibot and not self.ibot_separate_head:
                buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim)
                buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens)
                torch.index_select(
                    ibot_teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches],
                )
                tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher)
                teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens]
                masked_teacher_patch_tokens_after_head = tokens_after_head[
                    n_cls_tokens : n_cls_tokens + n_masked_patches
                ]
            elif do_ibot and self.ibot_separate_head:
                buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim)
                torch.index_select(
                    ibot_teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[:n_masked_patches],
                )
                teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
                masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[
                    :n_masked_patches
                ]
            else:
                teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
                masked_teacher_ibot_softmaxed_centered = None

            if self.cfg.train.centering == "centering":
                teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher(
                    teacher_cls_tokens_after_head, teacher_temp=teacher_temp
                ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])
                self.dino_loss.update_center(teacher_cls_tokens_after_head)
                if do_ibot:
                    masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0)
                    masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher(
                        masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp
                    )
                    masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0)
                    self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches])

            elif self.cfg.train.centering == "sinkhorn_knopp":
                teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher(
                    teacher_cls_tokens_after_head, teacher_temp=teacher_temp
                ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])

                if do_ibot:
                    masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher(
                        masked_teacher_patch_tokens_after_head,
                        teacher_temp=teacher_temp,
                        n_masked_patches_tensor=n_masked_patches_tensor,
                    )

            else:
                raise NotImplementedError

            return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered

        teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()
        reshard_fsdp_model(self.teacher)

        loss_dict = {}

        loss_accumulator = 0  # for backprop
        student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone(
            [global_crops, local_crops], masks=[masks, None], is_training=True
        )

        inputs_for_student_head_list = []

        # 1a: local crops cls tokens
        student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"]
        inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0))

        # 1b: global crops cls tokens
        student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"]
        inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0))

        # 1c: global crops patch tokens
        if do_ibot:
            _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1]
            ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"]
            buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim)
            buffer_tensor_patch_tokens[:n_masked_patches].copy_(
                torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list)
            )
            if not self.ibot_separate_head:
                inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0))
            else:
                student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[
                    :n_masked_patches
                ]

        # 2: run
        _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
        outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs))

        # 3a: local crops cls tokens
        student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)

        # 3b: global crops cls tokens
        student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)

        # 3c: global crops patch tokens
        if do_ibot and not self.ibot_separate_head:
            student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches]

        if n_local_crops > 0:
            dino_local_crops_loss = self.dino_loss(
                student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops),
                teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list,
            ) / (n_global_crops_loss_terms + n_local_crops_loss_terms)

            # store for display
            loss_dict["dino_local_crops_loss"] = dino_local_crops_loss

            # accumulate loss
            loss_accumulator += self.dino_loss_weight * dino_local_crops_loss

        # process global crops
        loss_scales = 2  # this is here since we process global crops together

        if do_dino:
            # compute loss
            dino_global_crops_loss = (
                self.dino_loss(
                    student_output_list=[student_global_cls_tokens_after_head],
                    teacher_out_softmaxed_centered_list=[
                        teacher_dino_softmaxed_centered_list.flatten(0, 1)
                    ],  # these were chunked and stacked in reverse so A is matched to B
                )
                * loss_scales
                / (n_global_crops_loss_terms + n_local_crops_loss_terms)
            )

            loss_dict["dino_global_crops_loss"] = dino_global_crops_loss

            # accumulate loss
            loss_accumulator += self.dino_loss_weight * dino_global_crops_loss

            student_cls_tokens = student_global_cls_tokens

            if self.do_koleo:
                koleo_loss = self.cfg.dino.koleo_loss_weight * sum(
                    self.koleo_loss(p) for p in student_cls_tokens.chunk(2)
                )  # we don't apply koleo loss between cls tokens of a same image
                loss_accumulator += koleo_loss
                loss_dict["koleo_loss"] = (
                    koleo_loss / loss_scales
                )  # this is to display the same losses as before but we can remove eventually

        if do_ibot:
            # compute loss
            ibot_patch_loss = (
                self.ibot_patch_loss.forward_masked(
                    student_global_masked_patch_tokens_after_head,
                    masked_teacher_ibot_softmaxed_centered,
                    student_masks_flat=masks,
                    n_masked_patches=n_masked_patches,
                    masks_weight=masks_weight,
                )
                * loss_scales
                * ibot_loss_scale
            )

            # store for display
            loss_dict["ibot_loss"] = ibot_patch_loss / 2

            # accumulate loss
            loss_accumulator += self.ibot_loss_weight * ibot_patch_loss

        self.backprop_loss(loss_accumulator)

        # self.fsdp_synchronize_streams()

        return loss_dict

    # def fsdp_synchronize_streams(self):
    #     if self.need_to_synchronize_fsdp_streams:
    #         torch.cuda.synchronize()
    #         self.student.dino_head._streams = (
    #             self.teacher.dino_head._streams
    #         ) = self.student.backbone._streams = self.teacher.backbone._streams
    #         self.need_to_synchronize_fsdp_streams = False

    # def update_teacher(self, m):
    #     student_param_list = []
    #     teacher_param_list = []
    #     with torch.no_grad():
    #         for k in self.student.keys():
    #             for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])):
    #                 student_param_list += ms.params
    #                 teacher_param_list += mt.params
    #         torch._foreach_mul_(teacher_param_list, m)
    #         torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m)

    def update_student_shadow(self, m):
        student_param_list = []
        student_shadow_param_list = []
        with torch.no_grad():
            for k in self.student.keys():
                for ms, mss in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.student_shadow[k])):
                    student_param_list += ms.params
                    student_shadow_param_list += mss.params
            torch._foreach_mul_(student_shadow_param_list, m)
            torch._foreach_add_(student_shadow_param_list, student_param_list, alpha=1 - m)

    def train(self):
        super().train()
        self.teacher.eval()
        self.student_shadow.eval()

    def get_maybe_fused_params_for_submodel(self, m):
        params_groups = get_params_groups_with_decay(
            model=m,
            lr_decay_rate=self.cfg.optim.layerwise_decay,
            patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult,
        )
        fused_params_groups = fuse_params_groups(params_groups)
        logger.info("fusing param groups")

        for g in fused_params_groups:
            g["foreach"] = True
        return fused_params_groups

    def get_params_groups(self):
        all_params_groups = []
        for m in self.student.values():
            all_params_groups += self.get_maybe_fused_params_for_submodel(m)
        return all_params_groups

    def prepare_for_distributed_training(self):
        logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
        if has_batchnorms(self.student):
            raise NotImplementedError

        # below will synchronize all student subnetworks across gpus:
        for k, v in self.student.items():
            self.student_shadow[k].load_state_dict(self.student[k].state_dict())
            # self.teacher[k].load_state_dict(self.student[k].state_dict())
            student_model_cfg = self.cfg.compute_precision.student[k]
            self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k])

        for k, v in self.teacher.items():
            teacher_model_cfg = self.cfg.compute_precision.teacher[k]
            self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k])
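Editor's sketch, not part of the commit: update_student_shadow() in the deleted DistillMetaArch above is a plain exponential moving average; stripped of the FSDP module walking it reduces to the update below.

import torch

@torch.no_grad()
def ema_update(shadow_params, student_params, m=0.992):
    # shadow <- m * shadow + (1 - m) * student
    torch._foreach_mul_(shadow_params, m)
    torch._foreach_add_(shadow_params, student_params, alpha=1 - m)

shadow, student = [torch.zeros(3)], [torch.ones(3)]
ema_update(shadow, student)
print(shadow[0])  # tensor([0.0080, 0.0080, 0.0080])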
@@ -37,7 +37,7 @@ class SSLMetaArch(nn.Module):
        student_model_dict = dict()
        teacher_model_dict = dict()

        student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
        student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)  # embed_dim=1024
        student_model_dict["backbone"] = student_backbone
        teacher_model_dict["backbone"] = teacher_backbone
        logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")
@@ -47,13 +47,13 @@ class SSLMetaArch(nn.Module):
            logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
            student_backbone.load_state_dict(chkpt["model"], strict=False)

        self.embed_dim = embed_dim
        self.dino_out_dim = cfg.dino.head_n_prototypes
        self.embed_dim = embed_dim  # 1024
        self.dino_out_dim = cfg.dino.head_n_prototypes  # 65536

        self.do_dino = cfg.dino.loss_weight > 0
        self.do_koleo = cfg.dino.koleo_loss_weight > 0
        self.do_ibot = cfg.ibot.loss_weight > 0
        self.ibot_separate_head = cfg.ibot.separate_head
        self.do_dino = cfg.dino.loss_weight > 0  # True
        self.do_koleo = cfg.dino.koleo_loss_weight > 0  # True
        self.do_ibot = cfg.ibot.loss_weight > 0  # True
        self.ibot_separate_head = cfg.ibot.separate_head  # False

        logger.info("OPTIONS -- DINO")
        if self.do_dino:
@@ -81,7 +81,7 @@ class SSLMetaArch(nn.Module):
        if self.do_dino or self.do_ibot:
            student_model_dict["dino_head"] = dino_head()
            teacher_model_dict["dino_head"] = dino_head()

        logger.info("OPTIONS -- IBOT")
        logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
        logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}")
@@ -134,18 +134,18 @@ class SSLMetaArch(nn.Module):
        assert n_global_crops == 2
        n_local_crops = self.cfg.crops.local_crops_number

        global_crops = images["collated_global_crops"].cuda(non_blocking=True)
        local_crops = images["collated_local_crops"].cuda(non_blocking=True)
        global_crops = images["collated_global_crops"].cuda(non_blocking=True)  # 128x3x224x224
        local_crops = images["collated_local_crops"].cuda(non_blocking=True)  # 512x3x96x96

        masks = images["collated_masks"].cuda(non_blocking=True)
        mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True)
        n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True)
        n_masked_patches = mask_indices_list.shape[0]
        upperbound = images["upperbound"]
        masks_weight = images["masks_weight"].cuda(non_blocking=True)
        masks = images["collated_masks"].cuda(non_blocking=True)  # 128x196
        mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True)  # 3730x
        n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True)  # 1x
        n_masked_patches = mask_indices_list.shape[0]  # 3730
        upperbound = images["upperbound"]  # 3771
        masks_weight = images["masks_weight"].cuda(non_blocking=True)  # 3730x

        n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)
        n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops
        n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)  # 16
        n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops  # 2

        do_dino = self.do_dino
        do_ibot = self.do_ibot
@@ -226,7 +226,7 @@ class SSLMetaArch(nn.Module):

            return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered

        teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()
        teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()  # [64x65536, 64x65536], 3730x65536
        reshard_fsdp_model(self.teacher)

        loss_dict = {}
@@ -22,6 +22,10 @@ from dinov2.utils.utils import CosineScheduler

from dinov2.train.ssl_meta_arch import SSLMetaArch

import sys
sys.path.append('/home/li.yu/code/JupiterCVML/europa/base/src/europa')
from dl.network.nextvit_brt import _get_nextvit


torch.backends.cuda.matmul.allow_tf32 = True  # PyTorch 1.12 sets this to False by default
logger = logging.getLogger("dinov2")
@@ -54,6 +58,9 @@ For python-based LazyConfig, use "path.key=value".
        type=str,
        help="Output directory to save logs and checkpoints",
    )
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    return parser
@@ -133,7 +140,7 @@ def do_test(cfg, model, iteration):

def do_train(cfg, model, resume=False):
    model.train()
    inputs_dtype = torch.half
    inputs_dtype = torch.half  # torch.float16
    fp16_scaler = model.fp16_scaler  # for mixed precision training

    # setup optimizer
@@ -152,8 +159,8 @@ def do_train(cfg, model, resume=False):

    start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH  # 1250
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH  # 125000

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
@@ -164,9 +171,9 @@ def do_train(cfg, model, resume=False):

    # setup data preprocessing

    img_size = cfg.crops.global_crops_size
    patch_size = cfg.student.patch_size
    n_tokens = (img_size // patch_size) ** 2
    img_size = cfg.crops.global_crops_size  # 224
    patch_size = cfg.student.patch_size  # 16
    n_tokens = (img_size // patch_size) ** 2  # 196
    mask_generator = MaskingGenerator(
        input_size=(img_size // patch_size, img_size // patch_size),
        max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
@@ -300,6 +307,18 @@ def main(args):
    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()

    # model = _get_nextvit(
    #     model_size="small",
    #     frozen_stages=-1,
    #     norm_eval=False,
    #     with_extra_norm=True,
    #     norm_cfg=dict(type="SyncBN", requires_grad=True),
    #     in_channels=3,
    # )
    # print('tunable parameters', sum(p.numel() for p in model.parameters() if p.requires_grad))
    # if args.distributed:
    #     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)

    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        iteration = (
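Editor's sketch, not part of the commit: the iteration bookkeeping behind the annotations above, using the config values shown in this diff.

OFFICIAL_EPOCH_LENGTH = 1250   # iterations counted as one "epoch"
epochs = 100
eval_period_iterations = 12500

max_iter = epochs * OFFICIAL_EPOCH_LENGTH
print(max_iter)                            # 125000
print(max_iter // eval_period_iterations)  # 10 evaluations over the run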
@@ -1,319 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import math
import os
from functools import partial

from fvcore.common.checkpoint import PeriodicCheckpointer
import torch

from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
import dinov2.distributed as distributed
from dinov2.fsdp import FSDPCheckpointer
from dinov2.logging import MetricLogger
from dinov2.utils.config import setup
from dinov2.utils.utils import CosineScheduler

from dinov2.train.distill_meta_arch import DistillMetaArch


torch.backends.cuda.matmul.allow_tf32 = True  # PyTorch 1.12 sets this to False by default
logger = logging.getLogger("dinov2")


def get_args_parser(add_help: bool = True):
    parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Whether to not attempt to resume from the checkpoint directory. ",
    )
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--eval", type=str, default="", help="Eval type to perform")
    parser.add_argument(
        "opts",
        help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
        """.strip(),
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        default="",
        type=str,
        help="Output directory to save logs and checkpoints",
    )

    return parser


def build_optimizer(cfg, params_groups):
    return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2))


def build_schedulers(cfg):
    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    lr = dict(
        base_value=cfg.optim["lr"],
        final_value=cfg.optim["min_lr"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
        warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
        start_warmup_value=0,
    )
    wd = dict(
        base_value=cfg.optim["weight_decay"],
        final_value=cfg.optim["weight_decay_end"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
    )
    momentum = dict(
        base_value=cfg.teacher["momentum_teacher"],
        final_value=cfg.teacher["final_momentum_teacher"],
        total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
    )
    teacher_temp = dict(
        base_value=cfg.teacher["teacher_temp"],
        final_value=cfg.teacher["teacher_temp"],
        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
        start_warmup_value=cfg.teacher["warmup_teacher_temp"],
    )

    lr_schedule = CosineScheduler(**lr)
    wd_schedule = CosineScheduler(**wd)
    momentum_schedule = CosineScheduler(**momentum)
    teacher_temp_schedule = CosineScheduler(**teacher_temp)
    last_layer_lr_schedule = CosineScheduler(**lr)

    last_layer_lr_schedule.schedule[
        : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
    ] = 0  # mimicking the original schedules

    logger.info("Schedulers ready.")

    return (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    )


def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
    for param_group in optimizer.param_groups:
        is_last_layer = param_group["is_last_layer"]
        lr_multiplier = param_group["lr_multiplier"]
        wd_multiplier = param_group["wd_multiplier"]
        param_group["weight_decay"] = wd * wd_multiplier
        param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier


def do_test(cfg, model, iteration):
    new_state_dict = model.teacher.state_dict()

    if distributed.is_main_process():
        iterstring = str(iteration)
        eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
        os.makedirs(eval_dir, exist_ok=True)
        # save teacher checkpoint
        teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
        torch.save({"teacher": new_state_dict}, teacher_ckp_path)


def do_train(cfg, model, resume=False):
    model.train()
    inputs_dtype = torch.half
    fp16_scaler = model.fp16_scaler  # for mixed precision training

    # setup optimizer

    optimizer = build_optimizer(cfg, model.get_params_groups())
    (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    ) = build_schedulers(cfg)

    # checkpointer
    checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)

    start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        period=3 * OFFICIAL_EPOCH_LENGTH,
        max_iter=max_iter,
        max_to_keep=3,
    )

    # setup data preprocessing

    img_size = cfg.crops.global_crops_size
    patch_size = cfg.student.patch_size
    n_tokens = (img_size // patch_size) ** 2
    mask_generator = MaskingGenerator(
        input_size=(img_size // patch_size, img_size // patch_size),
        max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
    )

    data_transform = DataAugmentationDINO(
        cfg.crops.global_crops_scale,
        cfg.crops.local_crops_scale,
        cfg.crops.local_crops_number,
        global_crops_size=cfg.crops.global_crops_size,
        local_crops_size=cfg.crops.local_crops_size,
    )

    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
        mask_probability=cfg.ibot.mask_sample_probability,
        n_tokens=n_tokens,
        mask_generator=mask_generator,
        dtype=inputs_dtype,
    )

    # setup data loader

    dataset = make_dataset(
        dataset_str=cfg.train.dataset_path,
        transform=data_transform,
        target_transform=lambda _: (),
    )
    # sampler_type = SamplerType.INFINITE
    sampler_type = SamplerType.SHARDED_INFINITE
    data_loader = make_data_loader(
        dataset=dataset,
        batch_size=cfg.train.batch_size_per_gpu,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        seed=start_iter,  # TODO: Fix this -- cfg.train.seed
        sampler_type=sampler_type,
        sampler_advance=0,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # training loop

    iteration = start_iter

    logger.info("Starting training from iteration {}".format(start_iter))
    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
    metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
    header = "Training"

    for data in metric_logger.log_every(
        data_loader,
        10,
        header,
        max_iter,
        start_iter,
    ):
        current_batch_size = data["collated_global_crops"].shape[0] / 2
        if iteration > max_iter:
            return

        # apply schedules

        lr = lr_schedule[iteration]
        wd = wd_schedule[iteration]
        mom = momentum_schedule[iteration]
        teacher_temp = teacher_temp_schedule[iteration]
        last_layer_lr = last_layer_lr_schedule[iteration]
        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)

        # compute losses

        optimizer.zero_grad(set_to_none=True)
        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)

        # clip gradients

        if fp16_scaler is not None:
            if cfg.optim.clip_grad:
                fp16_scaler.unscale_(optimizer)
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        else:
            if cfg.optim.clip_grad:
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            optimizer.step()

        # perform teacher EMA update

        model.update_student_shadow(mom)

        # logging

        if distributed.get_global_size() > 1:
            for v in loss_dict.values():
                torch.distributed.all_reduce(v)
        loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}

        if math.isnan(sum(loss_dict_reduced.values())):
            logger.info("NaN detected")
            raise AssertionError
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        metric_logger.update(lr=lr)
        metric_logger.update(wd=wd)
        metric_logger.update(mom=mom)
        metric_logger.update(last_layer_lr=last_layer_lr)
        metric_logger.update(current_batch_size=current_batch_size)
        metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)

        # checkpointing and testing

        if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
            do_test(cfg, model, f"training_{iteration}")
            torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        iteration = iteration + 1
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def main(args):
    cfg = setup(args)

    model = DistillMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()

    logger.info("Model:\n{}".format(model))
    if args.eval_only:
        iteration = (
            FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
            .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
            .get("iteration", -1)
            + 1
        )
        return do_test(cfg, model, f"manual_{iteration}")

    do_train(cfg, model, resume=not args.no_resume)


if __name__ == "__main__":
    args = get_args_parser(add_help=True).parse_args()
    main(args)
@@ -76,20 +76,20 @@ def get_slurm_executor_parameters(
) -> Dict[str, Any]:
    # create default parameters
    params = {
        # "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_gpu": 7,
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_gpu"] = 12
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_gpu"] = 12
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
@@ -7,6 +7,7 @@ import math
import logging
import os

import torch
from omegaconf import OmegaConf

import dinov2.distributed as distributed
@@ -46,6 +47,44 @@ def get_cfg_from_args(args):
    return cfg


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def default_setup(args):
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
@@ -66,7 +105,8 @@ def setup(args):
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    # default_setup(args)
    init_distributed_mode(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
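Editor's sketch, not part of the commit: init_distributed_mode() above keys off the environment variables that torchrun exports (RANK, WORLD_SIZE, LOCAL_RANK) or SLURM_PROCID under SLURM; a bare python launch sets none of them, so the function takes the args.distributed = False branch, which is the single-GPU path this commit targets.

import os

def detected_launch_mode() -> str:
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        return "torchrun / torch.distributed"
    if 'SLURM_PROCID' in os.environ:
        return "slurm"
    return "single process"

print(detected_launch_mode())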