[Fix] Fix freeze stages
parent
77d897aa01
commit
2fb930b21d
|
@ -546,7 +546,8 @@ class BEiTViT(BaseBackbone):
|
|||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze cls_token
|
||||
self.cls_token.requires_grad = False
|
||||
if self.with_cls_token:
|
||||
self.cls_token.requires_grad = False
|
||||
# freeze layers
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = self.layers[i - 1]
|
||||
|
@ -558,6 +559,9 @@ class BEiTViT(BaseBackbone):
|
|||
self.ln1.eval()
|
||||
for param in self.ln1.parameters():
|
||||
param.requires_grad = False
|
||||
self.ln2.eval()
|
||||
for param in self.ln2.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
|
|
Loading…
Reference in New Issue