Update vision_transformer.py

Account for register tokens in get_intermediate_layers
qasfb-patch-3
qasfb 2023-12-01 18:00:30 +01:00 committed by GitHub
parent da4b3825f0
commit b8e789ce84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -306,7 +306,7 @@ class DinoVisionTransformer(nn.Module):
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1:] for out in outputs]
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [