diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index c8c3ec2..4926108 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -306,7 +306,7 @@ class DinoVisionTransformer(nn.Module): if norm: outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1:] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] if reshape: B, _, w, h = x.shape outputs = [