diff --git a/mmcls/models/utils/attention.py b/mmcls/models/utils/attention.py index 4e795ed0..1aae72ae 100644 --- a/mmcls/models/utils/attention.py +++ b/mmcls/models/utils/attention.py @@ -261,10 +261,7 @@ class WindowMSAV2(BaseModule): attn = ( F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) logit_scale = torch.clamp( - self.logit_scale, - max=torch.log( - torch.tensor(1. / 0.01, - device=self.logit_scale.device))).exp() + self.logit_scale, max=np.log(1. / 0.01)).exp() attn = attn * logit_scale relative_position_bias_table = self.cpb_mlp(