From ea728f67fa26779f65e1cb9738ece458dbc86a42 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 14 Apr 2025 11:01:56 -0700
Subject: [PATCH] Improve several typing issues for flex vit, can (almost) work with jit if we bash h,w key into an int or str

---
 timm/layers/patch_embed.py             |  2 +-
 timm/models/vision_transformer_flex.py | 76 ++++++++++++++------------
 2 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index dab7acc9..519bb30c 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -321,7 +321,7 @@ def resample_patch_embed(
         verbose: bool = False,
 ):
     """ Standalone function (computes matrix on each call). """
-    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
+    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)"
     assert len(new_size) == 2, "New shape should only be hw (height, width)"
 
     old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index 63db0f70..7ddc3377 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -42,7 +42,7 @@ def batch_patchify(
         pad: bool = True,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     B, C, H, W = x.shape
-    ph, pw = to_2tuple(patch_size)
+    ph, pw = patch_size
 
     # Ensure the image is divisible by patch size
     if pad and (H % ph != 0 or W % pw != 0):
@@ -202,13 +202,12 @@ class FlexEmbeds(nn.Module):
         else:
             return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x, patch_coord=None, patch_valid=None):
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
         """Forward pass for combined embedding
 
         Args:
             x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
             patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
-            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
 
         Returns:
             Embedded tensor with position encoding and class/register tokens applied
@@ -216,7 +215,7 @@ class FlexEmbeds(nn.Module):
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
-        grid_size: Optional[Tuple[int, int]] = None
+        grid_size: Optional[List[int]] = None
 
         B = x.shape[0]
         if self.is_linear:
@@ -227,7 +226,7 @@ class FlexEmbeds(nn.Module):
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
-                naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
+                naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
             else:
                 _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
@@ -257,6 +256,7 @@ class FlexEmbeds(nn.Module):
             if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
+                assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
@@ -287,15 +287,19 @@ class FlexEmbeds(nn.Module):
         orig_h, orig_w = self.pos_embed.shape[1:3]
 
         # Determine unique grid sizes
-        size_to_indices = {}
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, (h, w) in enumerate(naflex_grid_sizes):
-            if not (h, w) in size_to_indices:
-                size_to_indices[(h, w)] = [bi]
+            #k = h << 16 | w # FIXME can get jit compat with this
+            k = (h, w)
+            if not k in size_to_indices:
+                size_to_indices[k] = [bi]
             else:
-                size_to_indices[(h, w)].append(bi)
+                size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        for (h, w), batch_indices in size_to_indices.items():
+        for k, batch_indices in size_to_indices.items():
+            h, w = k
+            #h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
@@ -315,7 +319,7 @@ class FlexEmbeds(nn.Module):
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
-            grid_size: Tuple[int, int],
+            grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] != orig_h or grid_size[1] != orig_w:
@@ -340,7 +344,7 @@ class FlexEmbeds(nn.Module):
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -357,7 +361,7 @@ def create_attention_mask(
         Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
         or None if patch_type is None
     """
-    patch_valid = patch_valid.bool()
+    patch_valid = patch_valid.to(torch.bool)
     B = patch_valid.shape[0]
 
     if num_prefix_tokens > 0:
@@ -373,7 +377,7 @@ def create_attention_mask(
 
 @register_notrace_function
 def create_attention_mask2(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
@@ -411,7 +415,7 @@ def create_attention_mask2(
 
 @register_notrace_function
 def create_pool_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid:torch.Tensor,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     patch_valid = patch_valid.bool()
@@ -773,8 +777,16 @@ class VisionTransformerFlex(nn.Module):
             patch_valid: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if attn_mask is None and patch_valid is not None:
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens,
+                dtype=x.dtype
+            )
+
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(x, patch_coord=patch_coord)
 
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None:
@@ -827,7 +839,7 @@ class VisionTransformerFlex(nn.Module):
 
             # For max pooling with mask
             masked_x = x.clone()
-            masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+            masked_x[~patch_valid] = -1e4 # torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]
 
             # Combine average and max
@@ -864,27 +876,23 @@ class VisionTransformerFlex(nn.Module):
         Returns:
             Model output tensor
         """
-        # Handle dictionary input from NaFlex collator
-        if isinstance(x, dict):
-            assert patch_coord is None
-            assert patch_valid is None
-            # Extract the required components from the dictionary
+        if isinstance(x, torch.Tensor):
+            patches = x
+        else:
+            # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
             patches = x['patches']
 
-            if False:
-                # DEBUG, reconstruct patches
-                for i in range(len(patches)):
-                    patch = patches[i][patch_valid[i]]
-                    h = (patch_coord[i, :, 0].max() + 1).item()
-                    w = (patch_coord[i, :, 1].max() + 1).item()
-                    patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
-                    patch = patch.reshape(3, h*16, w*16)
-                    from torchvision.utils import save_image
-                    save_image(patch, f'patch_{i}.jpg', normalize=True)
-        else:
-            patches = x
+            # DEBUG, reconstruct patches
+            # for i in range(len(patches)):
+            #     patch = patches[i][patch_valid[i]]
+            #     h = (patch_coord[i, :, 0].max() + 1).item()
+            #     w = (patch_coord[i, :, 1].max() + 1).item()
+            #     patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
+            #     patch = patch.reshape(3, h*16, w*16)
+            #     from torchvision.utils import save_image
+            #     save_image(patch, f'patch_{i}.jpg', normalize=True)
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
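
Note on the "bash h,w key into an int or str" idea from the subject line and the two FIXME comments: TorchScript restricts dict keys to simple types such as str, int, and float, so the Tuple[int, int] keys used by size_to_indices are the remaining scripting obstacle this patch works around. Below is a minimal sketch of the packed-int alternative, assuming each grid dimension fits in 16 bits; the helper names (pack_hw, unpack_hw, group_by_grid_size) are illustrative only and not part of the patch or of timm.

from typing import Dict, List, Tuple


def pack_hw(h: int, w: int) -> int:
    # Pack (h, w) into a single int so it can serve as a TorchScript-friendly dict key.
    # Assumes h < 65536 and w < 65536, which holds for any realistic patch grid.
    return h << 16 | w


def unpack_hw(k: int) -> Tuple[int, int]:
    # Recover (h, w) from a packed key.
    return k >> 16, k & 0xFFFF


def group_by_grid_size(grid_sizes: List[Tuple[int, int]]) -> Dict[int, List[int]]:
    # Same grouping pattern as _apply_learned_naflex_pos_embed, keyed by packed ints.
    size_to_indices: Dict[int, List[int]] = {}
    for bi, (h, w) in enumerate(grid_sizes):
        k = pack_hw(h, w)
        if k not in size_to_indices:
            size_to_indices[k] = [bi]
        else:
            size_to_indices[k].append(bi)
    return size_to_indices


# Example: three batch elements, two of which share a 12x16 grid.
if __name__ == '__main__':
    groups = group_by_grid_size([(12, 16), (16, 12), (12, 16)])
    for k, batch_indices in groups.items():
        print(unpack_hw(k), batch_indices)  # (12, 16) [0, 2] then (16, 12) [1]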
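
For context on the new forward_features fallback: per the signature and docstring shown in this patch, create_attention_mask turns a [B, N] patch_valid mask into an additive [B, seq_len, seq_len] mask with seq_len = N + num_prefix_tokens. The sketch below follows only that contract; it is not the timm implementation, and make_additive_attention_mask is an illustrative name.

import torch


def make_additive_attention_mask(
        patch_valid: torch.Tensor,
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # patch_valid: [B, N], nonzero (1) for real patches, zero for padding.
    patch_valid = patch_valid.to(torch.bool)
    B = patch_valid.shape[0]

    if num_prefix_tokens > 0:
        # Class/register tokens sit in front of the patch tokens and are always valid.
        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)  # [B, seq_len]

    # A query-key pair may attend only if both positions are valid.
    pair_valid = patch_valid.unsqueeze(1) & patch_valid.unsqueeze(2)  # [B, seq_len, seq_len]
    mask = torch.zeros(pair_valid.shape, dtype=dtype, device=patch_valid.device)
    mask.masked_fill_(~pair_valid, torch.finfo(dtype).min)
    return mask

Passing dtype=x.dtype, as the new forward_features code does, keeps the large negative fill representable when the model runs in reduced precision.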