diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index d3d9bfdf..b156ade0 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -514,7 +514,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.0,
             )
         else:
             q = q * self.scale
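
A minimal standalone sketch (not from the timm codebase; tensor shapes and the `training` flag here are illustrative) of why the guard is needed: `F.scaled_dot_product_attention` is a plain function with no notion of module train/eval state, so it applies dropout whenever `dropout_p > 0`, even at inference time. Gating `dropout_p` on `self.training`, as this diff does, restores deterministic eval behavior.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim) -- arbitrary small shapes for illustration
q = k = v = torch.randn(1, 1, 4, 8)

# Without the guard: dropout fires on every call, so two "eval" forward
# passes over identical inputs disagree.
out1 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
out2 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
print(torch.allclose(out1, out2))  # False: dropout is still active

# With the guard (the pattern introduced by this diff): dropout_p collapses
# to 0.0 outside training, making inference deterministic.
training = False  # stands in for self.training
out3 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5 if training else 0.0)
out4 = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5 if training else 0.0)
print(torch.allclose(out3, out4))  # True
```

This mirrors how `nn.Dropout` behaves (it becomes a no-op in eval mode), which is why the unguarded `self.attn_drop.p` was a silent behavior change relative to the non-fused branch.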