Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Improve type handling for arange & rel pos embeds, keep calculations in float32 until application (may change to apply in float32 in future). Prevent arange type hijacking by DeepSpeed Zero
commit 284e4ea7a9
parent 3234daf783
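The recurring fix in this commit is to request torch.arange with an explicit integer dtype and cast to float32 afterwards. A minimal sketch of the pattern (the helper name safe_float_arange is mine; the motivation is as the commit message states, namely that DeepSpeed Zero can hijack the float dtype handed to factory functions such as arange):

import torch

def safe_float_arange(*args) -> torch.Tensor:
    # Hypothetical helper illustrating the commit's pattern: asking arange for
    # float32 directly can be overridden if a framework patches the factory
    # function's float dtype. An integer arange is left alone, and the
    # explicit .to() cast then guarantees float32.
    return torch.arange(*args, dtype=torch.int64).to(torch.float32)

coords = safe_float_arange(-6, 7)  # e.g. a 7-wide window -> offsets -6..6
assert coords.dtype == torch.float32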
@@ -311,8 +311,8 @@ def gen_relative_log_coords(
 ):
     assert mode in ('swin', 'cr')
     # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
-    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
-    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
+    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
+    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
     relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
     relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
     if mode == 'swin':
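A quick standalone check of the new construction (the 7x7 window is an arbitrary example; shapes follow the 2*Wh-1, 2*Ww-1, 2 comment above):

import torch

win_size = (7, 7)
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
table = table.permute(1, 2, 0).contiguous()
assert table.shape == (13, 13, 2)    # 2*7-1 = 13 offsets per axis, 2 coords
assert table.dtype == torch.float32  # float32 regardless of factory patching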
@@ -15,13 +15,12 @@ def pixel_freq_bands(
         num_bands: int,
         max_freq: float = 224.,
         linear_bands: bool = True,
-        dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
 ):
     if linear_bands:
-        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
+        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
     else:
-        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
+        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
     return bands * torch.pi
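With the dtype parameter dropped, callers can no longer request lower precision here; bands always materialize in float32 and any downcast happens later, at the point of application. A condensed sketch of the resulting function (not the full timm source):

import math
import torch

def bands_sketch(num_bands: int, max_freq: float = 224., linear_bands: bool = True):
    # Both band layouts now hard-code float32.
    if linear_bands:
        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32)
    else:
        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32)
    return bands * torch.pi

assert bands_sketch(4).dtype == torch.float32
assert bands_sketch(4, linear_bands=False).dtype == torch.float32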
@@ -29,10 +28,10 @@ def freq_bands(
         num_bands: int,
         temperature: float = 10000.,
         step: int = 2,
-        dtype: torch.dtype = torch.float32,
         device: Optional[torch.device] = None,
 ) -> torch.Tensor:
-    bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
+    exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
+    bands = 1. / (temperature ** exp)
     return bands
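The int64-then-cast exponent yields exactly the familiar 1 / T ** (i / n) frequencies, just without trusting arange's float path. For example:

import torch

exp = torch.arange(0, 8, 2, dtype=torch.int64).to(torch.float32) / 8
bands = 1. / (10000. ** exp)
ref = 1. / (10000. ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))
assert torch.allclose(bands, ref)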
@@ -61,18 +60,20 @@ def build_sincos2d_pos_embed(
     """
     assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
     pos_dim = dim // 4
-    bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
+    bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device)

     if reverse_coord:
         feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
     grid = torch.stack(torch.meshgrid(
-        [torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
+        [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
+         for s in feat_shape])
+    ).flatten(1).transpose(0, 1)
     pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
     # FIXME add support for unflattened spatial dim?

     stack_dim = 2 if interleave_sin_cos else 1  # stack sin, cos, sin, cos instead of sin sin cos cos
     pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
-    return pos_emb
+    return pos_emb.to(dtype=dtype)
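A toy end-to-end trace of this hunk's effect (the sizes and temperature are made up): everything stays float32 through the sin/cos, and only the returned embedding is cast to the requested dtype.

import torch

feat_shape, dim, dtype = (4, 4), 8, torch.bfloat16
pos_dim = dim // 4
bands = 1. / (100. ** (torch.arange(0, pos_dim, dtype=torch.int64).to(torch.float32) / pos_dim))
grid = torch.stack(torch.meshgrid(
    [torch.arange(s, dtype=torch.int64).to(torch.float32) for s in feat_shape])
).flatten(1).transpose(0, 1)                    # (16, 2) float32 coords
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)  # (16, 2, pos_dim)
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=1).flatten(1)
assert pos_emb.dtype == torch.float32           # still float32 here
pos_emb = pos_emb.to(dtype=dtype)               # cast only on return
assert pos_emb.shape == (16, dim)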
@@ -112,7 +113,6 @@ def build_fourier_pos_embed(
             num_bands,
             float(max_res),
             linear_bands=linear_bands,
-            dtype=dtype,
             device=device,
         )
     else:
@@ -120,7 +120,6 @@ def build_fourier_pos_embed(
             num_bands,
             temperature=temperature,
             step=1,
-            dtype=dtype,
             device=device,
         )
     else:
@@ -130,9 +129,9 @@ def build_fourier_pos_embed(
         dtype = bands.dtype

     if in_pixels:
-        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
+        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
     else:
-        t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
+        t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]

     if ref_feat_shape is not None:
         # eva's scheme for resizing rope embeddings (ref shape = pretrain)
@@ -142,7 +141,7 @@ def build_fourier_pos_embed(
     grid = grid.unsqueeze(-1)
     pos = grid * bands

-    pos_sin, pos_cos = pos.sin(), pos.cos()
+    pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
     out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
     return out
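This is the one place build_fourier_pos_embed now leaves float32: sin and cos are evaluated at full precision and only their results are cast. In miniature:

import torch

target_dtype = torch.float16
pos = torch.arange(8, dtype=torch.int64).to(torch.float32).unsqueeze(-1) * 0.1
# Evaluate in float32, cast only the results for the consumer.
pos_sin, pos_cos = pos.sin().to(dtype=target_dtype), pos.cos().to(target_dtype)
assert pos_sin.dtype == pos_cos.dtype == target_dtype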
@@ -41,13 +41,13 @@ class PositionalEncodingFourier(nn.Module):
         device = self.token_projection.weight.device
         dtype = self.token_projection.weight.dtype
         inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool)
-        y_embed = inv_mask.cumsum(1, dtype=dtype)
-        x_embed = inv_mask.cumsum(2, dtype=dtype)
+        y_embed = inv_mask.cumsum(1, dtype=torch.float32)
+        x_embed = inv_mask.cumsum(2, dtype=torch.float32)
         eps = 1e-6
         y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

-        dim_t = torch.arange(self.hidden_dim, dtype=dtype, device=device)
+        dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)

         pos_x = x_embed[:, :, :, None] / dim_t
@@ -59,7 +59,7 @@ class PositionalEncodingFourier(nn.Module):
             (pos_y[:, :, :, 0::2].sin(),
              pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        pos = self.token_projection(pos)
+        pos = self.token_projection(pos.to(dtype))

         return pos
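The flip side of computing in float32 is matching the consumer's dtype at the boundary: under mixed precision the projection's weights may be half precision, so the float32 positional tensor is cast just before the projection, mirroring token_projection(pos.to(dtype)) above. A stand-in sketch (an nn.Linear in bfloat16 here, not timm's actual projection module):

import torch
import torch.nn as nn

proj = nn.Linear(4, 8).to(torch.bfloat16)     # stands in for token_projection
pos = torch.randn(2, 4, dtype=torch.float32)  # built in float32 as above
out = proj(pos.to(proj.weight.dtype))         # cast at the boundary
assert out.dtype == torch.bfloat16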
@@ -105,8 +105,8 @@ class WindowAttention(nn.Module):
         )

         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
+        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
         relative_coords_table = torch.stack(torch.meshgrid([
             relative_coords_h,
             relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2