pull/440/merge
Sanskar Agrawal 2024-07-12 12:51:21 +00:00 committed by GitHub
commit 52c0aec7fb
1 changed file with 6 additions and 6 deletions


@@ -176,7 +176,7 @@ class DinoVisionTransformer(nn.Module):
             nn.init.normal_(self.register_tokens, std=1e-6)
         named_apply(init_weights_vit_timm, self)

-    def interpolate_pos_encoding(self, x, w, h):
+    def interpolate_pos_encoding(self, x, h, w):
         previous_dtype = x.dtype
         npatch = x.shape[1] - 1
         N = self.pos_embed.shape[1] - 1
@@ -196,28 +196,28 @@ class DinoVisionTransformer(nn.Module):
             # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
             sx = float(w0 + self.interpolate_offset) / M
             sy = float(h0 + self.interpolate_offset) / M
-            kwargs["scale_factor"] = (sx, sy)
+            kwargs["scale_factor"] = (sy, sx)
         else:
             # Simply specify an output size instead of a scale factor
-            kwargs["size"] = (w0, h0)
+            kwargs["size"] = (h0, w0)
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
             **kwargs,
         )
-        assert (w0, h0) == patch_pos_embed.shape[-2:]
+        assert (h0, w0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

     def prepare_tokens_with_masks(self, x, masks=None):
-        B, nc, w, h = x.shape
+        B, nc, h, w = x.shape
         x = self.patch_embed(x)
         if masks is not None:
             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = x + self.interpolate_pos_encoding(x, w, h)
+        x = x + self.interpolate_pos_encoding(x, h, w)

         if self.register_tokens is not None:
             x = torch.cat(
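
The common thread in all six changes: PyTorch image tensors are laid out NCHW, so nn.functional.interpolate interprets both size and scale_factor as (height, width), and shape[-2:] is (H, W). The old code passed (w0, h0), which silently transposed the interpolated positional-embedding grid whenever the image was not square. A minimal sketch of the failure mode (not part of the commit; the embedding dimension and grid sizes below are made-up illustration values):

import torch
import torch.nn.functional as F

# A (1, dim, M, M) grid of positional embeddings, mirroring the
# reshape/permute in interpolate_pos_encoding (dim=384 is illustrative).
pos = torch.randn(1, 384, 16, 16)
h0, w0 = 24, 40  # hypothetical target patch grid for a non-square image

# interpolate's `size` is (out_H, out_W), matching the fixed kwargs["size"] = (h0, w0).
out = F.interpolate(pos, size=(h0, w0), mode="bicubic")

# shape[-2:] is (H, W), which is what the corrected assert compares against.
assert out.shape[-2:] == (h0, w0)

# The pre-fix call, size=(w0, h0), would instead return a (1, 384, 40, 24)
# grid, i.e. transposed position embeddings for every non-square input.

Square inputs (h0 == w0) are unaffected either way, which is presumably why the swap went unnoticed for so long.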