mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2111 from jamesljlster/enhance_vit_get_intermediate_layers
Vision Transformer (ViT) get_intermediate_layers: enhanced to support dynamic image size and saved computational costs from unused blocks
This commit is contained in:
commit
6ccb7d6a7c
@ -635,13 +635,14 @@ class VisionTransformer(nn.Module):
|
|||||||
) -> List[torch.Tensor]:
|
) -> List[torch.Tensor]:
|
||||||
outputs, num_blocks = [], len(self.blocks)
|
outputs, num_blocks = [], len(self.blocks)
|
||||||
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
|
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
|
||||||
|
last_index_to_take = max(take_indices)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
x = self._pos_embed(x)
|
x = self._pos_embed(x)
|
||||||
x = self.patch_drop(x)
|
x = self.patch_drop(x)
|
||||||
x = self.norm_pre(x)
|
x = self.norm_pre(x)
|
||||||
for i, blk in enumerate(self.blocks):
|
for i, blk in enumerate(self.blocks[: last_index_to_take + 1]):
|
||||||
x = blk(x)
|
x = blk(x)
|
||||||
if i in take_indices:
|
if i in take_indices:
|
||||||
outputs.append(x)
|
outputs.append(x)
|
||||||
@ -667,9 +668,12 @@ class VisionTransformer(nn.Module):
|
|||||||
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
|
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
|
||||||
|
|
||||||
if reshape:
|
if reshape:
|
||||||
grid_size = self.patch_embed.grid_size
|
patch_size = self.patch_embed.patch_size
|
||||||
|
batch, _, height, width = x.size()
|
||||||
outputs = [
|
outputs = [
|
||||||
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
|
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
|
||||||
|
.permute(0, 3, 1, 2)
|
||||||
|
.contiguous()
|
||||||
for out in outputs
|
for out in outputs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user