add configuration for merge block indexes
parent 050e5677c8
commit e68f8e290a
@@ -56,6 +56,7 @@ train:
   ffn_layer: swiglufused
   block_chunks: 0  # for distributed training
   num_register_tokens: 0  # 0 for no register tokens
+
 
 teacher:
   momentum_teacher: 0.994
@@ -77,6 +78,7 @@ evaluation:
 student:
   arch: vit_base
   patch_size: 14
+  merge_block_indexes: "1,3,7,11"  # comma-separated block indexes
 crops:
   global_crops_scale:
   - 0.32  # 0.32 default
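Note that the new value is a string, not a YAML list; it is parsed into integers later in train.py (see parse_merge_block_indexes below). A minimal illustration of what the loader actually hands over, assuming OmegaConf, which the DINOv2 config stack is built on:

    # The config value arrives as a plain string; parsing into a list of
    # ints happens later in train.py (parse_merge_block_indexes).
    from omegaconf import OmegaConf

    cfg = OmegaConf.create('student:\n  merge_block_indexes: "1,3,7,11"\n')
    print(type(cfg.student.merge_block_indexes).__name__)  # str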
@@ -13,6 +13,7 @@ logger = logging.getLogger("dinov2")
 
 def build_model(args, only_teacher=False, img_size=224):
     args.arch = args.arch.removesuffix("_memeff")
+    print("ARGS", args.merge_block_indexes)
     if "vit" in args.arch:
         vit_kwargs = dict(
             img_size=img_size,
@@ -26,6 +27,7 @@ def build_model(args, only_teacher=False, img_size=224):
             num_register_tokens=args.num_register_tokens,
             interpolate_offset=args.interpolate_offset,
             interpolate_antialias=args.interpolate_antialias,
+            merge_blocks_indexes=args.merge_block_indexes,
         )
         teacher = vits.__dict__[args.arch](**vit_kwargs)
         if only_teacher:
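Note the near-miss in naming: the config and args side uses merge_block_indexes while the ViT constructor parameter is merge_blocks_indexes. The diff wires the two together correctly, but the one-letter difference is easy to typo. A defensive accessor such as the hypothetical helper below (an illustration, not part of this commit) fails loudly instead of silently:

    # Hypothetical helper, not in this commit: resolve either spelling and
    # raise a clear error if neither is present on the args object.
    def resolve_merge_indexes(args):
        for name in ("merge_block_indexes", "merge_blocks_indexes"):
            if hasattr(args, name):
                return getattr(args, name)
        raise AttributeError("no merge block indexes found in config")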
@@ -25,12 +25,12 @@ class Merge_block(BaseModule):
         self.ada_c = ada_c
         # 784 - embedded dim + adapter_c
         self.embeded_dim = 768
-        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c)
-        self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
+        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c).to(torch.float16)
+        self.fc_2 = nn.Linear(mid_c, self.embeded_dim).to(torch.float16)
         self.return_ada = return_ada
 
         if self.return_ada:
-            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1)  # 1D conv instead of 3x3
+            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1).to(torch.float16)  # 1D conv instead of 3x3
         else:
             self.conv_3 = None
 
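Hard-casting these layers with .to(torch.float16) pins their weights to half precision, which can clash with fp32 inputs, optimizer state, and FSDP sharding elsewhere in DINOv2. A gentler pattern, sketched below under the assumption that training runs on CUDA and that merge_block/feat/ada stand in for the names above, keeps master weights in fp32 and lets autocast downcast at call time:

    # Hedged alternative, not in this commit: fp32 parameters + autocast.
    # Matmuls inside the block run in float16 while the master weights and
    # optimizer state stay in float32.
    import torch

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        fused = merge_block(feat, ada)  # illustrative names, see diff above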
@@ -77,6 +77,7 @@ class DinoVisionTransformer(nn.Module):
         fea_c_s=[384, 768, 1920],
         ada_c_s=[16, 32, 64],
         mid_c_s=[384, 576, 768],
+        merge_blocks_indexes=[],
     ):
         """
         Args:
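One Python footnote on the new parameter: a mutable default like merge_blocks_indexes=[] is created once at definition time and shared across all calls. It is harmless as long as nothing mutates it, but the conventional guard sidesteps the trap entirely (a style sketch, not part of this commit):

    # Style sketch: default to None, normalize inside the constructor.
    class Example:
        def __init__(self, merge_blocks_indexes=None):
            self.merge_blocks_indexes = list(merge_blocks_indexes or [])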
@@ -203,13 +204,25 @@ class DinoVisionTransformer(nn.Module):
             print("Loading input-level adapter:", input_level_adapter_path)
             adapter_state = torch.load(input_level_adapter_path, map_location="cpu")
             self.pre_encoder.load_state_dict(adapter_state)
 
-        self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
-        self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
-        self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
-
-        self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
-        # print(self.merge_blocks)
+        self.merge_blocks = []
+        self.merge_blocks_indexes = merge_blocks_indexes
+        # One Merge_block per configured index; only the last block gets
+        # return_ada=False, since nothing consumes its adapter stream.
+        for i, idx in enumerate(self.merge_blocks_indexes):
+            return_ada = i != len(self.merge_blocks_indexes) - 1
+            if i != 0 and i != len(self.merge_blocks_indexes) - 1:
+                k = 1  # middle blocks share the middle channel config
+            else:
+                k = 0 if i == 0 else len(fea_c_s) - 1  # first/last -> small/large config
+            merge_block = Merge_block(
+                fea_c=fea_c_s[k],
+                ada_c=ada_c_s[k],
+                mid_c=mid_c_s[k],
+                return_ada=return_ada,
+            ).to("cuda")
+            self.merge_blocks.append(merge_block)
+            # self.merge_blocks.to("cuda")
+        print(self.merge_blocks)
 
 
         self.init_weights()
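Because the blocks live in a plain Python list, nn.Module never registers them: model.to(...), state_dict(), the optimizer, and FSDP wrapping all skip their parameters, which is why each block needs its own .to("cuda") and why the commented-out list-level .to("cuda") could never work. A sketch of the registered variant, assuming the surrounding names from the diff:

    import torch.nn as nn

    # Hedged sketch, not part of this commit: nn.ModuleList registers each
    # Merge_block with the parent module, so device moves, checkpointing,
    # and FSDP see their parameters without explicit .to("cuda") calls.
    n = len(merge_blocks_indexes)
    self.merge_blocks = nn.ModuleList()
    for i, idx in enumerate(merge_blocks_indexes):
        k = 1 if 0 < i < n - 1 else (0 if i == 0 else len(fea_c_s) - 1)
        self.merge_blocks.append(
            Merge_block(fea_c=fea_c_s[k], ada_c=ada_c_s[k], mid_c=mid_c_s[k],
                        return_ada=(i != n - 1))
        )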
@@ -270,7 +283,7 @@ class DinoVisionTransformer(nn.Module):
 
     def prepare_tokens_with_masks(self, x, masks=None):
         B, nc, w, h = x.shape
-
+        print("BLOCKS NUM: ", len(self.blocks), len(self.merge_blocks))
         x_raw = self.pre_encoder(x)
         if self.w_lut:  # I1, I2, I3, I4
             ada = self.model_adapter([x_raw[0], x_raw[1], x_raw[2], x_raw[3]])
@@ -29,7 +29,7 @@ logger = logging.getLogger("dinov2")
 
 def get_args_parser(add_help: bool = True):
     parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
-    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+    parser.add_argument("--config-file", default="dinov2/dinov2/configs/train/custom.yaml", metavar="FILE", help="path to config file")
     parser.add_argument(
         "--no-resume",
         action="store_true",
@@ -293,14 +293,29 @@ def do_train(cfg, model, resume=False):
     metric_logger.synchronize_between_processes()
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
+import re
+from typing import List
+
+
+def parse_merge_block_indexes(config_value: str) -> List[int]:
+    """
+    Parse a string of merge block indexes into a list of integers.
+    Supports an explicit list like "1,3,7,11" or a range like "0..11".
+    """
+    if ".." in config_value:
+        start, end = map(int, config_value.split(".."))
+        return list(range(start, end + 1))
+    return list(map(int, re.split(r"\s*,\s*", config_value)))
+
 
 def main(args):
     cfg = setup(args)
 
+    cfg.student.merge_block_indexes = parse_merge_block_indexes(cfg.student.merge_block_indexes)
+    print("INDEXES", cfg.student.merge_block_indexes)
     model = SSLMetaArch(cfg).to(torch.device("cuda"))
     model.prepare_for_distributed_training()
 
-    logger.info("Model:\n{}".format(model))
+    # logger.info("Model:\n{}".format(model))
     if args.eval_only:
         iteration = (
             FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
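For reference, the two accepted spellings round-trip as expected (a quick sanity check against the parser above):

    # Sanity check of the two supported formats.
    assert parse_merge_block_indexes("1,3,7,11") == [1, 3, 7, 11]
    assert parse_merge_block_indexes("1, 3, 7, 11") == [1, 3, 7, 11]  # stray spaces tolerated
    assert parse_merge_block_indexes("0..11") == list(range(12))      # inclusive range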