Improve type handling for arange & rel pos embeds, keep calculations in float32 until application (may change to apply in float32 in future). Prevent arange type hijacking by DeepSpeed Zero

This commit is contained in:
Ross Wightman 2024-01-26 14:17:54 -08:00
parent 3234daf783
commit 284e4ea7a9
4 changed files with 20 additions and 21 deletions

View File

@ -311,8 +311,8 @@ def gen_relative_log_coords(
):
assert mode in ('swin', 'cr')
# as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':

View File

@ -15,13 +15,12 @@ def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
return bands * torch.pi
@ -29,10 +28,10 @@ def freq_bands(
num_bands: int,
temperature: float = 10000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
bands = 1. / (temperature ** exp)
return bands
@ -61,18 +60,20 @@ def build_sincos2d_pos_embed(
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
[torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
for s in feat_shape])
).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
return pos_emb.to(dtype=dtype)
def build_fourier_pos_embed(
@ -112,7 +113,6 @@ def build_fourier_pos_embed(
num_bands,
float(max_res),
linear_bands=linear_bands,
dtype=dtype,
device=device,
)
else:
@ -120,7 +120,6 @@ def build_fourier_pos_embed(
num_bands,
temperature=temperature,
step=1,
dtype=dtype,
device=device,
)
else:
@ -130,9 +129,9 @@ def build_fourier_pos_embed(
dtype = bands.dtype
if in_pixels:
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
else:
t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
@ -142,7 +141,7 @@ def build_fourier_pos_embed(
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
return out

View File

@ -41,13 +41,13 @@ class PositionalEncodingFourier(nn.Module):
device = self.token_projection.weight.device
dtype = self.token_projection.weight.dtype
inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool)
y_embed = inv_mask.cumsum(1, dtype=dtype)
x_embed = inv_mask.cumsum(2, dtype=dtype)
y_embed = inv_mask.cumsum(1, dtype=torch.float32)
x_embed = inv_mask.cumsum(2, dtype=torch.float32)
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.hidden_dim, dtype=dtype, device=device)
dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
pos_x = x_embed[:, :, :, None] / dim_t
@ -59,7 +59,7 @@ class PositionalEncodingFourier(nn.Module):
(pos_y[:, :, :, 0::2].sin(),
pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
pos = self.token_projection(pos)
pos = self.token_projection(pos.to(dtype))
return pos

View File

@ -105,8 +105,8 @@ class WindowAttention(nn.Module):
)
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([
relative_coords_h,
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2