[bug] Fix interpolation of positional embeddings (#378)
Use the size instead of a scale factor to specify the output size of nn.functional.interpolate(): this avoids rounding issues that led to a mismatched output size, and it consistently generates the same output size as the previous kludge (from facebookresearch/dino#8).
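As a quick illustration of the rounding issue (a standalone sketch, not part of this commit; the shapes and target sizes are made up for the example):

    import torch
    import torch.nn.functional as F

    # A 14x14 grid of 64-dim positional embeddings (illustrative shapes).
    pos = torch.randn(1, 64, 14, 14)
    w0, h0 = 31, 31

    # Old style: pass a float scale factor. The operator recovers the output
    # size as floor(14 * scale), which can land one pixel short of the target
    # for some sizes, hence the historical +0.1 offset nudging the scale up.
    offset = 0.1
    scaled = F.interpolate(
        pos,
        scale_factor=((w0 + offset) / 14, (h0 + offset) / 14),
        mode="bicubic",
    )

    # New style: request the output size directly, so no rounding is involved.
    sized = F.interpolate(pos, size=(w0, h0), mode="bicubic")

    assert scaled.shape[-2:] == sized.shape[-2:] == (31, 31)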
parent 2302b6bf46
commit e1277af2ba
@@ -188,21 +188,25 @@ class DinoVisionTransformer(nn.Module):
         dim = x.shape[-1]
         w0 = w // self.patch_size
         h0 = h // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
-
-        sqrt_N = math.sqrt(N)
-        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+        assert N == M * M
+        kwargs = {}
+        if self.interpolate_offset:
+            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+            sx = float(w0 + self.interpolate_offset) / M
+            sy = float(h0 + self.interpolate_offset) / M
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
-            scale_factor=(sx, sy),
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
+            **kwargs,
         )
-
-        assert int(w0) == patch_pos_embed.shape[-2]
-        assert int(h0) == patch_pos_embed.shape[-1]
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

@@ -306,7 +310,7 @@ class DinoVisionTransformer(nn.Module):
         if norm:
             outputs = [self.norm(out) for out in outputs]
         class_tokens = [out[:, 0] for out in outputs]
-        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
         if reshape:
             B, _, w, h = x.shape
             outputs = [
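As the backward-compatibility comment in the first hunk notes, the underlying operators use both the output size and the scale factors, which is why the legacy scale-factor path is kept behind interpolate_offset. A standalone sanity check (assumed shapes, not part of this commit) that both call styles agree on the output size:

    import torch
    import torch.nn.functional as F

    grid = torch.randn(1, 4, 16, 16)  # a 16x16 patch grid, small dim for speed
    for w0 in range(17, 64):
        legacy = F.interpolate(
            grid,
            scale_factor=((w0 + 0.1) / 16, (w0 + 0.1) / 16),
            mode="bicubic",
        )
        exact = F.interpolate(grid, size=(w0, w0), mode="bicubic")
        # Same output size either way; the interpolated values can still
        # differ slightly between the two paths, which is why the offset
        # kludge is preserved for backward compatibility.
        assert legacy.shape == exact.shape == (1, 4, w0, w0)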