[Fix] Fix freeze stages

pull/1705/head
fanqiNO1 2023-07-10 18:36:33 +08:00
parent 77d897aa01
commit 2fb930b21d
1 changed files with 5 additions and 1 deletions

View File

@ -546,7 +546,8 @@ class BEiTViT(BaseBackbone):
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze cls_token
self.cls_token.requires_grad = False
if self.with_cls_token:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
@ -558,6 +559,9 @@ class BEiTViT(BaseBackbone):
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
self.ln2.eval()
for param in self.ln2.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]