Modified ViT get_intermediate_layers() to support dynamic image size

This commit is contained in:
Cheng-Ling Lai 2024-03-16 23:05:26 +08:00
parent 6e6f3686a7
commit 4731e4efc4
No known key found for this signature in database
GPG Key ID: AD39A7F7B0EB1AC5

View File

@ -667,9 +667,12 @@ class VisionTransformer(nn.Module):
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
grid_size = self.patch_embed.grid_size
patch_size = self.patch_embed.patch_size
batch, _, height, width = x.size()
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
out.reshape(batch, int(math.ceil(height / patch_size[0])), int(math.ceil(width / patch_size[1])), -1)
.permute(0, 3, 1, 2)
.contiguous()
for out in outputs
]