diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index 00bfc055..bef462e0 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -384,6 +384,7 @@ class EfficientViTMSRA(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): + x = self.global_pool(x) return x if pre_logits else self.head(x) def forward(self, x):