add original model generation

pull/393/head
cm090999 2024-03-14 10:19:10 +01:00
parent e9bbd22734
commit 602c340026
2 changed files with 28 additions and 24 deletions

View File

@ -39,5 +39,7 @@ def build_model(args, only_teacher=False, img_size=224):
return student, teacher, embed_dim
def build_model_from_cfg(cfg, only_teacher=False):
    """Build the model(s) described by a config object.

    Thin wrapper around ``build_model`` that pulls its arguments out of
    ``cfg``.

    Args:
        cfg: Config object. ``cfg.student`` is forwarded as the model
            arguments and ``cfg.crops.global_crops_size`` is used as the
            image size.
        only_teacher: Forwarded unchanged to ``build_model``.

    Returns:
        Whatever ``build_model`` returns for these arguments.
    """
    student_args = cfg.student
    global_crop_size = cfg.crops.global_crops_size
    return build_model(
        student_args,
        only_teacher=only_teacher,
        img_size=global_crop_size,
    )
def build_model_from_cfg(cfg, only_teacher=False, img_size=None):
    """Build the model(s) described by a config object.

    Thin wrapper around ``build_model`` that pulls its arguments out of
    ``cfg`` while letting the caller override the image size.

    Args:
        cfg: Config object. ``cfg.student`` is forwarded as the model
            arguments; ``cfg.crops.global_crops_size`` supplies the image
            size when ``img_size`` is not given.
        only_teacher: Forwarded unchanged to ``build_model``.
        img_size: Image size passed to ``build_model``. When ``None``
            (the default), ``cfg.crops.global_crops_size`` is used.

    Returns:
        Whatever ``build_model`` returns for these arguments.
    """
    # The config is consulted only when the caller gave no explicit size.
    resolved_size = cfg.crops.global_crops_size if img_size is None else img_size
    return build_model(
        cfg.student,
        only_teacher=only_teacher,
        img_size=resolved_size,
    )

View File

@ -63,35 +63,39 @@ class SSLMetaArch(nn.Module):
student_model_dict = dict()
teacher_model_dict = dict()
# TODO: Hard-coded embedding size
if cfg.student.arch == "dinov2_vits14":
embed_dim = 384
elif cfg.student.arch == "dinov2_vitb14":
embed_dim = 768
student_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
teacher_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
# # TODO: Hard-coded embedding size
# if cfg.student.arch == "dinov2_vits14":
# embed_dim = 384
# elif cfg.student.arch == "dinov2_vitb14":
# embed_dim = 768
# student_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
# teacher_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
# TODO: Hard-coded image size
student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg, img_size = 518)
if cfg.block_expansion.enabled:
logger.info("OPTIONS -- block expansion ENABLED")
student_backbone = expand_dinov2(student_backbone, cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
teacher_backbone = expand_dinov2(teacher_backbone, cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
# student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
# if cfg.block_expansion.enabled:
# logger.info(f"OPTIONS -- block expansion: expanded blocks: {cfg.block_expansion.expanded_blocks}, fix weights of original blocks")
# block_positions = get_expanded_block_positions(cfg.block_expansion.expanded_blocks)
# for p in student_backbone.blocks.parameters():
# p.requires_grad = False
# for pos in block_positions:
# for p in student_backbone.blocks[pos].parameters():
# p.requires_grad = True
student_model_dict["backbone"] = student_backbone
teacher_model_dict["backbone"] = teacher_backbone
logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")
# Get checkpoint from hub
if cfg.student.arch == "vit_small":
model_name = "dinov2_vits14"
elif cfg.student.arch == "vit_base":
model_name = "dinov2_vitb14"
model_weights = torch.hub.load('facebookresearch/dinov2', model_name)
student_model_dict["backbone"].load_state_dict(model_weights.state_dict(), strict=True)
# Block Expansion
if cfg.block_expansion.enabled:
logger.info("OPTIONS -- block expansion ENABLED")
student_model_dict["backbone"] = expand_dinov2(student_model_dict["backbone"], cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
teacher_model_dict["backbone"] = expand_dinov2(teacher_model_dict["backbone"], cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
if cfg.student.pretrained_weights:
chkpt = torch.load(cfg.student.pretrained_weights)
logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
student_backbone.load_state_dict(chkpt["model"], strict=False)
@ -169,8 +173,6 @@ class SSLMetaArch(nn.Module):
p.requires_grad = False
if cfg.block_expansion.enabled:
print(self.student.backbone)
logger.info(f"OPTIONS -- block expansion: expanded blocks: {cfg.block_expansion.expanded_blocks}, fix weights of original blocks")
block_positions = get_expanded_block_positions(cfg.block_expansion.expanded_blocks)