From db395d35b17540ec7613c98ca748ce3cbf0c83db Mon Sep 17 00:00:00 2001 From: Fabien Merceron PRL <117649630+fabien-merceron@users.noreply.github.com> Date: Fri, 14 Jul 2023 09:43:19 +0200 Subject: [PATCH] fix_freeze_without_cls_token_vit (#1693) --- mmpretrain/models/backbones/vision_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py index d77ac863..0e9efa34 100644 --- a/mmpretrain/models/backbones/vision_transformer.py +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -436,7 +436,8 @@ class VisionTransformer(BaseBackbone): for param in self.pre_norm.parameters(): param.requires_grad = False # freeze cls_token - self.cls_token.requires_grad = False + if self.cls_token: + self.cls_token.requires_grad = False # freeze layers for i in range(1, self.frozen_stages + 1): m = self.layers[i - 1]