Merge pull request #2111 from jamesljlster/enhance_vit_get_intermediate_layers
Vision Transformer (ViT) get_intermediate_layers: enhanced to support dynamic image size and to save computational cost by skipping unused blocks
commit 6ccb7d6a7c
@@ -635,13 +635,14 @@ class VisionTransformer(nn.Module):
     ) -> List[torch.Tensor]:
         outputs, num_blocks = [], len(self.blocks)
         take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
+        last_index_to_take = max(take_indices)
 
         # forward pass
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
         x = self.norm_pre(x)
-        for i, blk in enumerate(self.blocks):
+        for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
             x = blk(x)
             if i in take_indices:
                 outputs.append(x)
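The first hunk caps the block loop at the deepest requested index, so trailing blocks that no caller asked for are never executed. A minimal standalone sketch of the index selection, with illustrative values (num_blocks = 12 matches a ViT-Base depth):

num_blocks = 12

n = 4  # int: take the last n blocks
take_indices = set(range(num_blocks - n, num_blocks))  # {8, 9, 10, 11}
last_index_to_take = max(take_indices)                 # 11 -> all 12 blocks still run

n = (2, 5)  # sequence: take exactly these block indices
take_indices = set(n)                                  # {2, 5}
last_index_to_take = max(take_indices)                 # 5 -> blocks 6..11 are skipped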
@@ -667,9 +668,12 @@ class VisionTransformer(nn.Module):
         outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
 
         if reshape:
-            grid_size = self.patch_embed.grid_size
+            patch_size = self.patch_embed.patch_size
+            batch, _, height, width = x.size()
             outputs = [
-                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
+                out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
+                .permute(0, 3, 1, 2)
+                .contiguous()
                 for out in outputs
             ]
 
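The second hunk derives the token grid from the actual input size rather than the fixed self.patch_embed.grid_size, which is what lets reshape=True work when the input resolution differs from the size the model was configured with. A sketch of the grid arithmetic, with illustrative patch and input sizes:

import math

# Grid computed as in the new reshape path (patch/input sizes are assumptions).
patch_size = (16, 16)
height, width = 256, 320
grid_h = int(math.ceil(height / patch_size[0]))  # 16
grid_w = int(math.ceil(width / patch_size[1]))   # 20
# Each output is reshaped to (batch, grid_h, grid_w, C), then permuted to NCHW.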
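Taken together, a hedged end-to-end sketch of what the enhancement enables; the model name, the dynamic_img_size flag, and the input resolution are illustrative assumptions, not part of this commit:

import torch
import timm

model = timm.create_model('vit_base_patch16_224', pretrained=False, dynamic_img_size=True)
model.eval()

x = torch.randn(1, 3, 256, 320)  # non-square, non-default resolution
with torch.no_grad():
    feats = model.get_intermediate_layers(x, n=2, reshape=True)

for f in feats:
    print(f.shape)  # with 16x16 patches: torch.Size([1, 768, 16, 20])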