mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix norm at last feat_idx
This commit is contained in:
parent
e16d385592
commit
9aedecbb5f
@ -479,7 +479,6 @@ class MambaOut(nn.Module):
|
||||
self.reset_classifier(0, '')
|
||||
return take_indices
|
||||
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
|
@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module):
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
@ -449,6 +449,7 @@ class Nest(nn.Module):
|
||||
|
||||
# forward pass
|
||||
x = self.patch_embed(x)
|
||||
last_idx = self.num_blocks - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.levels
|
||||
else:
|
||||
@ -457,13 +458,18 @@ class Nest(nn.Module):
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
if norm and feat_idx == last_idx:
|
||||
x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
# Layer norm done over channel dim only (to NHWC and back)
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
if feat_idx == last_idx:
|
||||
# Layer norm done over channel dim only (to NHWC and back)
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
@ -588,6 +588,7 @@ class NextViT(nn.Module):
|
||||
|
||||
# forward pass
|
||||
x = self.stem(x)
|
||||
last_idx = len(self.stages) - 1
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
@ -596,12 +597,17 @@ class NextViT(nn.Module):
|
||||
for feat_idx, stage in enumerate(stages):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
if feat_idx == last_idx:
|
||||
x_inter = self.norm(x) if norm else x
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
@ -309,7 +309,7 @@ class RDNet(nn.Module):
|
||||
x = self.stem(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
last_idx = len(self.dense_stages)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
dense_stages = self.dense_stages
|
||||
else:
|
||||
@ -324,7 +324,8 @@ class RDNet(nn.Module):
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm_pre(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm_pre(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
@ -574,7 +574,7 @@ class ResNetV2(nn.Module):
|
||||
x = self.stem(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
|
||||
last_idx = len(self.stages)
|
||||
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
|
||||
stages = self.stages
|
||||
else:
|
||||
@ -583,12 +583,17 @@ class ResNetV2(nn.Module):
|
||||
for feat_idx, stage in enumerate(stages, start=1):
|
||||
x = stage(x)
|
||||
if feat_idx in take_indices:
|
||||
intermediates.append(x)
|
||||
if feat_idx == last_idx:
|
||||
x_inter = self.norm(x) if norm else x
|
||||
intermediates.append(x_inter)
|
||||
else:
|
||||
intermediates.append(x)
|
||||
|
||||
if intermediates_only:
|
||||
return intermediates
|
||||
|
||||
x = self.norm(x)
|
||||
if feat_idx == last_idx:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user