add configuration for merge block indexes

pull/511/head
Veronikkkka 2025-03-10 17:13:41 +00:00
parent 050e5677c8
commit e68f8e290a
5 changed files with 45 additions and 13 deletions

View File

@@ -56,6 +56,7 @@ train:
  ffn_layer: swiglufused
  block_chunks: 0 # for distributed training
  num_register_tokens: 0 # 0 for no register tokens
teacher:
  momentum_teacher: 0.994
@@ -77,6 +78,7 @@ evaluation:
student:
  arch: vit_base
  patch_size: 14
  merge_block_indexes: "1,3,7,11" # comma-separated block indexes; a range like "0..11" is also accepted
crops:
  global_crops_scale:
  - 0.32 #0.32 default
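
The new key is read as a plain string and converted to a list of ints by parse_merge_block_indexes (see the train.py diff below). A minimal sketch of both accepted forms, assuming this is the same custom.yaml that train.py points at:

student:
  merge_block_indexes: "1,3,7,11" # explicit comma-separated indexes
  # merge_block_indexes: "0..11"  # or an inclusive range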

View File

@@ -13,6 +13,7 @@ logger = logging.getLogger("dinov2")
def build_model(args, only_teacher=False, img_size=224):
    args.arch = args.arch.removesuffix("_memeff")
    print("ARGS", args.merge_block_indexes)
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
@@ -26,6 +27,7 @@ def build_model(args, only_teacher=False, img_size=224):
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
            merge_blocks_indexes=args.merge_block_indexes,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
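
Note the naming seam here: the config key is merge_block_indexes (singular "block"), while the ViT constructor parameter is merge_blocks_indexes. A minimal sketch of the call path, assuming the upstream DINOv2 wrapper build_model_from_cfg is unchanged:

from dinov2.models import build_model_from_cfg  # assumed upstream helper

# cfg.student.merge_block_indexes must already be parsed into e.g. [1, 3, 7, 11]
student, teacher, embed_dim = build_model_from_cfg(cfg)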

View File

@@ -25,12 +25,12 @@ class Merge_block(BaseModule):
        self.ada_c = ada_c
        # 784 = embedded dim (768) + adapter_c (16)
        self.embeded_dim = 768
        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c).to(torch.float16)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim).to(torch.float16)
        self.return_ada = return_ada
        if self.return_ada:
            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1) # 1D Conv instead of 3x3
            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1).to(torch.float16) # 1D Conv instead of 3x3
        else:
            self.conv_3 = None
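
The hard cast to torch.float16 means any tensor reaching these layers must be half precision too, otherwise PyTorch raises a dtype-mismatch error at the matmul. A minimal standalone sketch of the constraint (sizes chosen to mirror the defaults above; not repo code):

import torch
import torch.nn as nn

fc_1 = nn.Linear(768 * 2, 384).to(torch.float16)  # same cast as in Merge_block
x = torch.randn(2, 768 * 2)                       # float32 by default
# fc_1(x) would raise a RuntimeError (Half vs Float dtype mismatch)
y = fc_1(x.to(torch.float16))                     # cast the input to half first

An alternative that avoids scattering manual casts is running the forward pass under torch.autocast("cuda", dtype=torch.float16), which inserts the casts automatically.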

View File

@@ -77,6 +77,7 @@ class DinoVisionTransformer(nn.Module):
        fea_c_s = [384, 768, 1920],
        ada_c_s = [16, 32, 64],
        mid_c_s = [384, 576, 768],
        merge_blocks_indexes=[],
    ):
        """
        Args:
@@ -203,13 +204,25 @@
print("Loading input-level adapter:", input_level_adapter_path)
adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
self.pre_encoder.load_state_dict(adapter_state)
self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
# print(self.merge_blocks)
self.merge_blocks = []
self.merge_blocks_indexes = merge_blocks_indexes
# Loop through the merge_blocks_indexes and create Merge_block instances
for i, idx in enumerate(self.merge_blocks_indexes):
return_ada = False if i == len(self.merge_blocks_indexes) - 1 else True # Only the last block gets return_ada=False
if i != 0 or i != len(self.merge_blocks_indexes) - 1:
k = 1
else:
k = i
merge_block = Merge_block(
fea_c=fea_c_s[k],
ada_c=ada_c_s[k],
mid_c=mid_c_s[k],
return_ada=return_ada
).to("cuda")
self.merge_blocks.append(merge_block)
# self.merge_blocks.to("cuda")
print(self.merge_blocks)
self.init_weights()
@@ -270,7 +283,7 @@ class DinoVisionTransformer(nn.Module):
    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        print("BLOCKS NUM: ", len(self.blocks), len(self.merge_blocks))
        x_raw = self.pre_encoder(x)
        if self.w_lut: # I1, I2, I3, I4
            ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]])
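
With the corrected index-to-config mapping above, the configured indexes "1,3,7,11" resolve to the three channel configs like this; a standalone sketch of just that selection logic (not repo code):

fea_c_s = [384, 768, 1920]
indexes = [1, 3, 7, 11]
for i, idx in enumerate(indexes):
    k = 0 if i == 0 else (len(fea_c_s) - 1 if i == len(indexes) - 1 else 1)
    print(idx, "->", fea_c_s[k])  # 1 -> 384, 3 -> 768, 7 -> 768, 11 -> 1920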

View File

@@ -29,7 +29,7 @@ logger = logging.getLogger("dinov2")
def get_args_parser(add_help: bool = True):
    parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--config-file", default="dinov2/dinov2/configs/train/custom.yaml", metavar="FILE", help="path to config file")
    parser.add_argument(
        "--no-resume",
        action="store_true",
@@ -293,14 +293,29 @@ def do_train(cfg, model, resume=False):
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


import re
from typing import List


def parse_merge_block_indexes(config_value: str) -> List[int]:
    """
    Parse a string of merge block indexes into a list of integers.
    Supports an explicit list like "1,3,7,11" or an inclusive range like "0..11".
    """
    if ".." in config_value:
        start, end = map(int, config_value.split(".."))
        return list(range(start, end + 1))
    return list(map(int, re.split(r"\s*,\s*", config_value)))
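# Illustrative behavior (hypothetical quick check, not part of the commit):
#   parse_merge_block_indexes("1,3,7,11") -> [1, 3, 7, 11]
#   parse_merge_block_indexes("0..3")     -> [0, 1, 2, 3]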


def main(args):
    cfg = setup(args)
    cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes)
    print("INDEXES", cfg.student.merge_block_indexes)
    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()
    logger.info("Model:\n{}".format(model))
    # logger.info("Model:\n{}".format(model))
    if args.eval_only:
        iteration = (
            FSDPCheckpointer(model, save_dir=cfg.train.output_dir)