fix final norm only apply at last indice

This commit is contained in:
Ryan 2025-05-06 00:56:36 +08:00 committed by Ross Wightman
parent e0ae4db8fc
commit 72b2a09106

View File

@ -318,8 +318,11 @@ class RDNet(nn.Module):
feat_idx += 1
x = stage(x)
if feat_idx in take_indices:
# NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
intermediates.append(x)
if norm and feat_idx == last_idx:
x_inter = self.norm_pre(x) # applying final norm to last intermediate
else:
x_inter = x
intermediates.append(x_inter)
if intermediates_only:
return intermediates