fix norm at last feat_idx

Authored by Ryan on 2025-05-05 23:15:39 +08:00; committed by Ross Wightman
parent e16d385592
commit 9aedecbb5f
6 changed files with 30 additions and 12 deletions
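
All six files get the same treatment inside their forward_intermediates() method: the final norm layer is no longer applied unconditionally, but only when the stage just produced is the last feature index (and, for the captured intermediates, only when norm is requested). Below is a minimal sketch of that pattern on a made-up module; TinyNet and its layers are illustrative only, not code from any of the touched files.

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Illustrative stand-in: stem -> stages -> final norm (not a timm model).
    def __init__(self, dim: int = 32, num_stages: int = 4):
        super().__init__()
        self.stem = nn.Conv2d(3, dim, 3, stride=2, padding=1)
        self.stages = nn.ModuleList(
            [nn.Conv2d(dim, dim, 3, padding=1) for _ in range(num_stages)])
        self.norm = nn.BatchNorm2d(dim)

    def forward_intermediates(self, x, take_indices=(1, 3), norm=False,
                              stop_early=False, intermediates_only=False):
        intermediates = []
        last_idx = len(self.stages) - 1
        stages = self.stages if not stop_early else self.stages[:max(take_indices) + 1]
        x = self.stem(x)
        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                # The fix: only the last stage's output goes through the final
                # norm when norm=True; earlier intermediates stay un-normed.
                if norm and feat_idx == last_idx:
                    intermediates.append(self.norm(x))
                else:
                    intermediates.append(x)
        if intermediates_only:
            return intermediates
        # Same guard on the way out: the trailing norm only applies if the loop
        # actually reached the last stage (it may not when stop_early truncates it).
        if feat_idx == last_idx:
            x = self.norm(x)
        return x, intermediates


if __name__ == '__main__':
    model = TinyNet().eval()
    out, feats = model.forward_intermediates(torch.randn(1, 3, 64, 64), norm=True)
    print(out.shape, [f.shape for f in feats])

The feat_idx == last_idx guard is what each hunk below adds, model by model.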


@@ -479,7 +479,6 @@ class MambaOut(nn.Module):
             self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)


@@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module):
         if intermediates_only:
             return intermediates
-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)
         return x, intermediates


@@ -449,6 +449,7 @@ class Nest(nn.Module):
         # forward pass
         x = self.patch_embed(x)
+        last_idx = self.num_blocks - 1
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = self.levels
         else:
@@ -457,13 +458,18 @@ class Nest(nn.Module):
         for feat_idx, stage in enumerate(stages):
             x = stage(x)
             if feat_idx in take_indices:
-                intermediates.append(x)
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+                    intermediates.append(x_inter)
+                else:
+                    intermediates.append(x)
         if intermediates_only:
             return intermediates
-        # Layer norm done over channel dim only (to NHWC and back)
-        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        if feat_idx == last_idx:
+            # Layer norm done over channel dim only (to NHWC and back)
+            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x, intermediates
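
The two-permute pattern in the Nest hunk is just how nn.LayerNorm gets applied over the channel dim of an NCHW tensor: it normalizes over the trailing dimension(s), so the map is moved to NHWC, normed, then moved back. A standalone check with an assumed 64-channel feature map:

import torch
import torch.nn as nn

x = torch.randn(2, 64, 14, 14)   # NCHW feature map, shape assumed for the example
norm = nn.LayerNorm(64)          # normalizes over the last dimension only

# To NHWC so channels are last, apply the norm, then back to NCHW.
y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert y.shape == x.shape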


@@ -588,6 +588,7 @@ class NextViT(nn.Module):
         # forward pass
         x = self.stem(x)
+        last_idx = len(self.stages) - 1
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = self.stages
         else:
@@ -596,12 +597,17 @@ class NextViT(nn.Module):
         for feat_idx, stage in enumerate(stages):
             x = stage(x)
             if feat_idx in take_indices:
-                intermediates.append(x)
+                if feat_idx == last_idx:
+                    x_inter = self.norm(x) if norm else x
+                    intermediates.append(x_inter)
+                else:
+                    intermediates.append(x)
         if intermediates_only:
             return intermediates
-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)
         return x, intermediates


@@ -309,7 +309,7 @@ class RDNet(nn.Module):
         x = self.stem(x)
         if feat_idx in take_indices:
             intermediates.append(x)
+        last_idx = len(self.dense_stages)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             dense_stages = self.dense_stages
         else:
@@ -324,7 +324,8 @@ class RDNet(nn.Module):
         if intermediates_only:
             return intermediates
-        x = self.norm_pre(x)
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
         return x, intermediates
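
One detail worth noting in the RDNet hunk above and the ResNetV2 hunks below: the stem output is captured as feature index 0 and the stage loop is enumerated from 1 (explicit in the ResNetV2 diff), so the last feature index is len(stages) rather than len(stages) - 1, which is why these two models drop the usual - 1. A toy illustration with placeholder stage names:

stages = ['stage1', 'stage2', 'stage3', 'stage4']  # placeholder names
last_idx = len(stages)                             # 4, because the stem already took index 0
for feat_idx, name in enumerate(stages, start=1):
    print(feat_idx, name, feat_idx == last_idx)    # only stage4 is the last feat_idx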


@@ -574,7 +574,7 @@ class ResNetV2(nn.Module):
         x = self.stem(x)
         if feat_idx in take_indices:
             intermediates.append(x)
+        last_idx = len(self.stages)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = self.stages
         else:
@@ -583,12 +583,17 @@ class ResNetV2(nn.Module):
         for feat_idx, stage in enumerate(stages, start=1):
             x = stage(x)
             if feat_idx in take_indices:
-                intermediates.append(x)
+                if feat_idx == last_idx:
+                    x_inter = self.norm(x) if norm else x
+                    intermediates.append(x_inter)
+                else:
+                    intermediates.append(x)
         if intermediates_only:
             return intermediates
-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)
         return x, intermediates
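
On the calling side nothing changes; these loops back forward_intermediates(), the API these models expose for grabbing intermediate feature maps. A quick way to exercise the fixed behavior, assuming timm is installed and a ResNetV2 variant such as resnetv2_50 is available (the kwargs mirror the ones in the diffs):

import torch
import timm

model = timm.create_model('resnetv2_50', pretrained=False).eval()

with torch.no_grad():
    # norm=True now normalizes only the final intermediate (the last feat_idx);
    # earlier feature maps come back un-normed, matching the loops above.
    final, intermediates = model.forward_intermediates(
        torch.randn(1, 3, 224, 224), indices=4, norm=True)

print(final.shape, [t.shape for t in intermediates])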