add original model generation
parent
e9bbd22734
commit
602c340026
|
@ -39,5 +39,7 @@ def build_model(args, only_teacher=False, img_size=224):
|
|||
return student, teacher, embed_dim
|
||||
|
||||
|
||||
def build_model_from_cfg(cfg, only_teacher=False):
|
||||
return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
|
||||
def build_model_from_cfg(cfg, only_teacher=False, img_size=None):
|
||||
if img_size is None:
|
||||
img_size = cfg.crops.global_crops_size
|
||||
return build_model(cfg.student, only_teacher=only_teacher, img_size=img_size)
|
||||
|
|
|
@ -63,35 +63,39 @@ class SSLMetaArch(nn.Module):
|
|||
student_model_dict = dict()
|
||||
teacher_model_dict = dict()
|
||||
|
||||
# TODO: Hard-coded embedding size
|
||||
if cfg.student.arch == "dinov2_vits14":
|
||||
embed_dim = 384
|
||||
elif cfg.student.arch == "dinov2_vitb14":
|
||||
embed_dim = 768
|
||||
student_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
|
||||
teacher_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
|
||||
# # TODO: Hard-coded embedding size
|
||||
# if cfg.student.arch == "dinov2_vits14":
|
||||
# embed_dim = 384
|
||||
# elif cfg.student.arch == "dinov2_vitb14":
|
||||
# embed_dim = 768
|
||||
# student_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
|
||||
# teacher_backbone = get_downloaded_dino_vit_s(cfg.student.arch, embed_dim)
|
||||
|
||||
# TODO: Hard-coded image size
|
||||
student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg, img_size = 518)
|
||||
|
||||
if cfg.block_expansion.enabled:
|
||||
logger.info("OPTIONS -- block expansion ENABLED")
|
||||
student_backbone = expand_dinov2(student_backbone, cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
|
||||
teacher_backbone = expand_dinov2(teacher_backbone, cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
|
||||
|
||||
# student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
|
||||
# if cfg.block_expansion.enabled:
|
||||
# logger.info(f"OPTIONS -- block expansion: expanded blocks: {cfg.block_expansion.expanded_blocks}, fix weights of original blocks")
|
||||
# block_positions = get_expanded_block_positions(cfg.block_expansion.expanded_blocks)
|
||||
# for p in student_backbone.blocks.parameters():
|
||||
# p.requires_grad = False
|
||||
# for pos in block_positions:
|
||||
# for p in student_backbone.blocks[pos].parameters():
|
||||
# p.requires_grad = True
|
||||
|
||||
student_model_dict["backbone"] = student_backbone
|
||||
teacher_model_dict["backbone"] = teacher_backbone
|
||||
logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")
|
||||
|
||||
# Get checkpoint from hub
|
||||
if cfg.student.arch == "vit_small":
|
||||
model_name = "dinov2_vits14"
|
||||
elif cfg.student.arch == "vit_base":
|
||||
model_name = "dinov2_vitb14"
|
||||
model_weights = torch.hub.load('facebookresearch/dinov2', model_name)
|
||||
student_model_dict["backbone"].load_state_dict(model_weights.state_dict(), strict=True)
|
||||
|
||||
# Block Expansion
|
||||
if cfg.block_expansion.enabled:
|
||||
logger.info("OPTIONS -- block expansion ENABLED")
|
||||
student_model_dict["backbone"] = expand_dinov2(student_model_dict["backbone"], cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
|
||||
teacher_model_dict["backbone"] = expand_dinov2(teacher_model_dict["backbone"], cfg.block_expansion.expanded_blocks, cfg.block_expansion.path_dropout)
|
||||
|
||||
if cfg.student.pretrained_weights:
|
||||
|
||||
chkpt = torch.load(cfg.student.pretrained_weights)
|
||||
logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
|
||||
student_backbone.load_state_dict(chkpt["model"], strict=False)
|
||||
|
@ -169,8 +173,6 @@ class SSLMetaArch(nn.Module):
|
|||
p.requires_grad = False
|
||||
|
||||
if cfg.block_expansion.enabled:
|
||||
|
||||
print(self.student.backbone)
|
||||
|
||||
logger.info(f"OPTIONS -- block expansion: expanded blocks: {cfg.block_expansion.expanded_blocks}, fix weights of original blocks")
|
||||
block_positions = get_expanded_block_positions(cfg.block_expansion.expanded_blocks)
|
||||
|
|
Loading…
Reference in New Issue