Fix reparameterize for NextViT. Fix #2187

This commit is contained in:
Ross Wightman 2024-05-27 14:48:58 -07:00
parent e748805be3
commit 3c0283f9ef

View File

@ -197,7 +197,7 @@ class NextConvBlock(nn.Module):
def reparameterize(self):
if not self.is_fused:
merge_pre_bn(self.mlp.fc1, self.norm)
self.norm = None
self.norm = nn.Identity()
self.is_fused = True
def forward(self, x):