Merge pull request #2111 from jamesljlster/enhance_vit_get_intermediate_layers

Vision Transformer (ViT) get_intermediate_layers: enhanced to support dynamic image sizes and to avoid the computational cost of running unused blocks
Ross Wightman 2024-03-18 13:41:18 -07:00 committed by GitHub
commit 6ccb7d6a7c
1 changed file with 7 additions and 3 deletions
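In practice, the two changes together mean get_intermediate_layers can be fed inputs whose size differs from the configured one, and only runs the blocks it actually needs. A minimal usage sketch, assuming a timm version that includes this change (the model name and the dynamic_img_size flag are illustrative choices, not part of this commit):

import torch
import timm

# Any ViT model works here; dynamic_img_size=True lets the patch/pos embed
# adapt to inputs that differ from the configured resolution.
model = timm.create_model('vit_base_patch16_224', dynamic_img_size=True)
model.eval()

x = torch.randn(1, 3, 240, 240)  # not the configured 224x224

with torch.no_grad():
    # Outputs of the last 4 blocks, reshaped to (B, C, H', W') feature maps.
    feats = model.get_intermediate_layers(x, n=4, reshape=True)

for f in feats:
    print(f.shape)  # torch.Size([1, 768, 15, 15]) since 240 / 16 = 15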

timm/models/vision_transformer.py

@@ -635,13 +635,14 @@ class VisionTransformer(nn.Module):
     ) -> List[torch.Tensor]:
         outputs, num_blocks = [], len(self.blocks)
         take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+        last_index_to_take = max(take_indices)
 
         # forward pass
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks):
+        for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
             x = blk(x)
             if i in take_indices:
                 outputs.append(x)
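The slice self.blocks[: last_index_to_take + 1] stops the forward loop at the deepest requested block, so trailing blocks never execute. A standalone sketch of the index logic (the helper name is hypothetical):

from typing import List, Set, Union

def take_set(n: Union[int, List[int]], num_blocks: int) -> Set[int]:
    # An int n selects the last n blocks; a list selects explicit indices.
    return set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

num_blocks = 12
take_indices = take_set([2, 5], num_blocks)
last_index_to_take = max(take_indices)
print(sorted(take_indices), last_index_to_take)  # [2, 5] 5
# Only blocks 0..5 run; the old loop would have executed all 12.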
@@ -667,9 +668,12 @@ class VisionTransformer(nn.Module):
             outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
         if reshape:
-            grid_size = self.patch_embed.grid_size
+            patch_size = self.patch_embed.patch_size
+            batch, _, height, width = x.size()
             outputs = [
-                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
+                out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
+                .permute(0, 3, 1, 2)
+                .contiguous()
                 for out in outputs
             ]
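The second hunk derives the token grid from the actual input size instead of the static patch_embed.grid_size, which is what makes reshape=True correct for dynamic image sizes; math.ceil also covers inputs that are not an exact multiple of the patch size (e.g. when the patch embed pads the input). A quick check of the arithmetic, with illustrative values:

import math

patch_size = (16, 16)
batch, height, width = 1, 250, 200  # non-square, not divisible by 16
grid_h = int(math.ceil(height / patch_size[0]))  # 250 / 16 = 15.625 -> 16
grid_w = int(math.ceil(width / patch_size[1]))   # 200 / 16 = 12.5   -> 13
print(grid_h, grid_w)  # 16 13
# Each intermediate output reshapes to (batch, grid_h, grid_w, C), then
# permutes to (batch, C, grid_h, grid_w).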