diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 1ab2e736..e560ec9a 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -635,13 +635,14 @@ class VisionTransformer(nn.Module): ) -> List[torch.Tensor]: outputs, num_blocks = [], len(self.blocks) take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) + last_index_to_take = max(take_indices) # forward pass x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - for i, blk in enumerate(self.blocks): + for i, blk in enumerate(self.blocks[: last_index_to_take + 1]): x = blk(x) if i in take_indices: outputs.append(x)