Saved computational costs of get_intermediate_layers() from unused blocks

This commit is contained in:
Cheng-Ling Lai 2024-03-16 23:18:36 +08:00
parent 4731e4efc4
commit db06b56d34
No known key found for this signature in database
GPG Key ID: AD39A7F7B0EB1AC5

View File

@ -635,13 +635,14 @@ class VisionTransformer(nn.Module):
) -> List[torch.Tensor]:
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
last_index_to_take = max(take_indices)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
x = blk(x)
if i in take_indices:
outputs.append(x)