From 2fb930b21d785ab3a9f451312405c49b98cfdc34 Mon Sep 17 00:00:00 2001 From: fanqiNO1 <1848839264@qq.com> Date: Mon, 10 Jul 2023 18:36:33 +0800 Subject: [PATCH] [Fix] Fix freeze stages --- mmpretrain/models/backbones/beit.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py index cc29b586..372be363 100644 --- a/mmpretrain/models/backbones/beit.py +++ b/mmpretrain/models/backbones/beit.py @@ -546,7 +546,8 @@ class BEiTViT(BaseBackbone): for param in self.patch_embed.parameters(): param.requires_grad = False # freeze cls_token - self.cls_token.requires_grad = False + if self.with_cls_token: + self.cls_token.requires_grad = False # freeze layers for i in range(1, self.frozen_stages + 1): m = self.layers[i - 1] @@ -558,6 +559,9 @@ class BEiTViT(BaseBackbone): self.ln1.eval() for param in self.ln1.parameters(): param.requires_grad = False + self.ln2.eval() + for param in self.ln2.parameters(): + param.requires_grad = False def forward(self, x): B = x.shape[0]