Merge d283fc2b96 into e1277af2ba (commit 52c0aec7fb)
@@ -176,7 +176,7 @@ class DinoVisionTransformer(nn.Module):
             nn.init.normal_(self.register_tokens, std=1e-6)
         named_apply(init_weights_vit_timm, self)

-    def interpolate_pos_encoding(self, x, w, h):
+    def interpolate_pos_encoding(self, x, h, w):
         previous_dtype = x.dtype
         npatch = x.shape[1] - 1
         N = self.pos_embed.shape[1] - 1
@@ -196,28 +196,28 @@ class DinoVisionTransformer(nn.Module):
             # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
             sx = float(w0 + self.interpolate_offset) / M
             sy = float(h0 + self.interpolate_offset) / M
-            kwargs["scale_factor"] = (sx, sy)
+            kwargs["scale_factor"] = (sy, sx)
         else:
             # Simply specify an output size instead of a scale factor
-            kwargs["size"] = (w0, h0)
+            kwargs["size"] = (h0, w0)
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
             **kwargs,
         )
-        assert (w0, h0) == patch_pos_embed.shape[-2:]
+        assert (h0, w0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

     def prepare_tokens_with_masks(self, x, masks=None):
-        B, nc, w, h = x.shape
+        B, nc, h, w = x.shape
         x = self.patch_embed(x)
         if masks is not None:
             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = x + self.interpolate_pos_encoding(x, w, h)
+        x = x + self.interpolate_pos_encoding(x, h, w)

         if self.register_tokens is not None:
             x = torch.cat(
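
Note on the reordering: it aligns the variable names with PyTorch's NCHW convention, where x.shape unpacks as (B, C, H, W) and, for a 4D input, torch.nn.functional.interpolate expects both size and scale_factor ordered as (height, width). A minimal standalone sketch of that convention (the tensor shapes and embedding dimension here are illustrative, not taken from this patch):

import torch
import torch.nn.functional as F

# Positional embeddings laid out as a (B, dim, M, M) grid; dim=384 and M=16 are illustrative.
pos = torch.randn(1, 384, 16, 16)

# Target patch grid for a non-square image: h0 rows, w0 columns.
h0, w0 = 20, 30

# For NCHW input, interpolate orders the output spatial dims as (H, W).
out = F.interpolate(pos, size=(h0, w0), mode="bicubic", antialias=True)
assert out.shape[-2:] == (h0, w0)

This is the same (H, W) ordering that the updated assert in interpolate_pos_encoding checks against patch_pos_embed.shape[-2:].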