diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py index 4620e81d..42a3b280 100644 --- a/timm/layers/pos_embed_rel.py +++ b/timm/layers/pos_embed_rel.py @@ -311,8 +311,8 @@ def gen_relative_log_coords( ): assert mode in ('swin', 'cr') # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well - relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) + relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32) + relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2 if mode == 'swin': diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index e850c034..4675aba2 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -15,13 +15,12 @@ def pixel_freq_bands( num_bands: int, max_freq: float = 224., linear_bands: bool = True, - dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ): if linear_bands: - bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device) else: - bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) + bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device) return bands * torch.pi @@ -29,10 +28,10 @@ def freq_bands( num_bands: int, temperature: float = 10000., step: int = 2, - dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, ) -> torch.Tensor: - bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) + exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands + bands = 1. / (temperature ** exp) return bands @@ -61,18 +60,20 @@ def build_sincos2d_pos_embed( """ assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' pos_dim = dim // 4 - bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) + bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device) if reverse_coord: feat_shape = feat_shape[::-1] # stack W, H instead of H, W grid = torch.stack(torch.meshgrid( - [torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) + [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) + for s in feat_shape]) + ).flatten(1).transpose(0, 1) pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) # FIXME add support for unflattened spatial dim? stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) - return pos_emb + return pos_emb.to(dtype=dtype) def build_fourier_pos_embed( @@ -112,7 +113,6 @@ def build_fourier_pos_embed( num_bands, float(max_res), linear_bands=linear_bands, - dtype=dtype, device=device, ) else: @@ -120,7 +120,6 @@ def build_fourier_pos_embed( num_bands, temperature=temperature, step=1, - dtype=dtype, device=device, ) else: @@ -130,9 +129,9 @@ def build_fourier_pos_embed( dtype = bands.dtype if in_pixels: - t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape] + t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape] else: - t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape] + t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape] if ref_feat_shape is not None: # eva's scheme for resizing rope embeddings (ref shape = pretrain) @@ -142,7 +141,7 @@ def build_fourier_pos_embed( grid = grid.unsqueeze(-1) pos = grid * bands - pos_sin, pos_cos = pos.sin(), pos.cos() + pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype) out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos] return out diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 8a9704ea..661669d5 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -41,13 +41,13 @@ class PositionalEncodingFourier(nn.Module): device = self.token_projection.weight.device dtype = self.token_projection.weight.dtype inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool) - y_embed = inv_mask.cumsum(1, dtype=dtype) - x_embed = inv_mask.cumsum(2, dtype=dtype) + y_embed = inv_mask.cumsum(1, dtype=torch.float32) + x_embed = inv_mask.cumsum(2, dtype=torch.float32) eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.hidden_dim, dtype=dtype, device=device) + dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) pos_x = x_embed[:, :, :, None] / dim_t @@ -59,7 +59,7 @@ class PositionalEncodingFourier(nn.Module): (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - pos = self.token_projection(pos) + pos = self.token_projection(pos.to(dtype)) return pos diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index eb1feeb5..0815856b 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -105,8 +105,8 @@ class WindowAttention(nn.Module): ) # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32) relative_coords_table = torch.stack(torch.meshgrid([ relative_coords_h, relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2