Improve several typing issues for flex vit, can (almost) work with jit if we bash h,w key into an int or str

Ross Wightman 2025-04-14 11:01:56 -07:00
parent 97341fec51
commit ea728f67fa
2 changed files with 43 additions and 35 deletions

View File

@@ -321,7 +321,7 @@ def resample_patch_embed(
verbose: bool = False,
):
""" Standalone function (computes matrix on each call). """
assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)"
assert len(new_size) == 2, "New shape should only be hw (height, width)"
old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])

View File

@@ -42,7 +42,7 @@ def batch_patchify(
pad: bool = True,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
B, C, H, W = x.shape
ph, pw = to_2tuple(patch_size)
ph, pw = patch_size
# Ensure the image is divisible by patch size
if pad and (H % ph != 0 or W % pw != 0):
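For reference, a minimal sketch of what a patchify step like `batch_patchify` does, assuming a `[B, C, H, W]` input and a plain `(ph, pw)` tuple. The function name and padding details here are illustrative, not the exact timm implementation:

```python
import torch
import torch.nn.functional as F
from typing import Tuple

def patchify_sketch(x: torch.Tensor, patch_size: Tuple[int, int], pad: bool = True):
    # x: [B, C, H, W] -> patches [B, N, ph*pw*C], grid size [gh, gw]
    B, C, H, W = x.shape
    ph, pw = patch_size
    if pad and (H % ph != 0 or W % pw != 0):
        # Pad right/bottom so H and W become divisible by the patch size
        x = F.pad(x, (0, (pw - W % pw) % pw, 0, (ph - H % ph) % ph))
        H, W = x.shape[-2:]
    gh, gw = H // ph, W // pw
    patches = x.unfold(2, ph, ph).unfold(3, pw, pw)        # [B, C, gh, gw, ph, pw]
    patches = patches.permute(0, 2, 3, 4, 5, 1).reshape(B, gh * gw, -1)
    return patches, [gh, gw]
```

Dropping `to_2tuple` inside the function means the caller must already pass a concrete 2-tuple, which keeps the argument type unambiguous for TorchScript.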
@@ -202,13 +202,12 @@ class FlexEmbeds(nn.Module):
else:
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
def forward(self, x, patch_coord=None, patch_valid=None):
def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
"""Forward pass for combined embedding
Args:
x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
Returns:
Embedded tensor with position encoding and class/register tokens applied
@@ -216,7 +215,7 @@ class FlexEmbeds(nn.Module):
"""
# Apply patch embedding
naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
grid_size: Optional[Tuple[int, int]] = None
grid_size: Optional[List[int]] = None
B = x.shape[0]
if self.is_linear:
@@ -227,7 +226,7 @@ class FlexEmbeds(nn.Module):
# Calculate the appropriate grid size from coords
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
else:
_assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
@@ -257,6 +256,7 @@ class FlexEmbeds(nn.Module):
if naflex_grid_sizes is not None:
self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
else:
assert grid_size is not None
self._apply_learned_pos_embed(x, grid_size=grid_size)
elif self.pos_embed_type == 'rope':
assert False, "ROPE not yet implemented"
@@ -287,15 +287,19 @@ class FlexEmbeds(nn.Module):
orig_h, orig_w = self.pos_embed.shape[1:3]
# Determine unique grid sizes
size_to_indices = {}
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, (h, w) in enumerate(naflex_grid_sizes):
if not (h, w) in size_to_indices:
size_to_indices[(h, w)] = [bi]
#k = h << 16 | w # FIXME can get jit compat with this
k = (h, w)
if not k in size_to_indices:
size_to_indices[k] = [bi]
else:
size_to_indices[(h, w)].append(bi)
size_to_indices[k].append(bi)
# Handle each batch element separately with its own grid size
for (h, w), batch_indices in size_to_indices.items():
for k, batch_indices in size_to_indices.items():
h, w = k
#h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
# Interpolate only once for this (h, w)
if (h == orig_h) and (w == orig_w):
pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
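The FIXME comments refer to the trick in the commit message: TorchScript does not accept tuple dict keys, so scripting works only if the (h, w) key is bashed into an int or str. A rough sketch of the packing/unpacking, assuming h and w each fit in 16 bits (example grid sizes are illustrative):

```python
from typing import Dict, List

def pack_hw(h: int, w: int) -> int:
    # Pack two grid dimensions into one int so it can serve as a
    # TorchScript-compatible dict key (int/str/float keys only).
    return (h << 16) | w

def unpack_hw(k: int):
    return k >> 16, k & 0xFFFF

size_to_indices: Dict[int, List[int]] = {}
for bi, (h, w) in enumerate([(14, 14), (10, 16), (14, 14)]):
    size_to_indices.setdefault(pack_hw(h, w), []).append(bi)
# {pack_hw(14, 14): [0, 2], pack_hw(10, 16): [1]}
```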
@@ -315,7 +319,7 @@ class FlexEmbeds(nn.Module):
def _apply_learned_pos_embed(
self,
x: torch.Tensor,
grid_size: Tuple[int, int],
grid_size: List[int],
):
orig_h, orig_w = self.pos_embed.shape[1:3]
if grid_size[0] != orig_h or grid_size[1] != orig_w:
@@ -340,7 +344,7 @@
@register_notrace_function
def create_attention_mask(
patch_valid: Optional[torch.Tensor],
patch_valid: torch.Tensor,
num_prefix_tokens: int = 0,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
@@ -357,7 +361,7 @@ def create_attention_mask(
Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
or None if patch_type is None
"""
patch_valid = patch_valid.bool()
patch_valid = patch_valid.to(torch.bool)
B = patch_valid.shape[0]
if num_prefix_tokens > 0:
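For context, a minimal sketch of how a padding mask of this kind can be built from `patch_valid`. It mirrors the `[B, seq_len, seq_len]` shape described in the docstring above but is not the exact timm implementation (the broadcast to attention heads is omitted):

```python
import torch

def make_attention_mask(
        patch_valid: torch.Tensor,       # [B, N], 1 = real patch, 0 = padding
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    patch_valid = patch_valid.to(torch.bool)
    B = patch_valid.shape[0]
    if num_prefix_tokens > 0:
        # Class/register tokens are always valid
        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    # Pairwise validity -> additive mask: 0 where both tokens are valid, else a large negative
    pair_valid = patch_valid.unsqueeze(1) & patch_valid.unsqueeze(2)   # [B, L, L]
    mask = torch.zeros(pair_valid.shape, dtype=dtype, device=patch_valid.device)
    return mask.masked_fill(~pair_valid, torch.finfo(dtype).min)
```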
@@ -373,7 +377,7 @@ def create_attention_mask(
@register_notrace_function
def create_attention_mask2(
patch_valid: Optional[torch.Tensor],
patch_valid: torch.Tensor,
num_prefix_tokens: int = 0,
q_len: Optional[int] = None,
dtype: torch.dtype = torch.float32,
@@ -411,7 +415,7 @@ def create_attention_mask2(
@register_notrace_function
def create_pool_mask(
patch_valid: Optional[torch.Tensor],
patch_valid: torch.Tensor,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
patch_valid = patch_valid.bool()
@@ -773,8 +777,16 @@ class VisionTransformerFlex(nn.Module):
patch_valid: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn_mask is None and patch_valid is not None:
attn_mask = create_attention_mask(
patch_valid,
num_prefix_tokens=self.num_prefix_tokens,
dtype=x.dtype
)
# Pass through embedding module with patch coordinate/type support
x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
x = self.embeds(x, patch_coord=patch_coord)
# Apply transformer blocks with masked attention if mask provided
if attn_mask is not None:
@@ -827,7 +839,7 @@ class VisionTransformerFlex(nn.Module):
# For max pooling with mask
masked_x = x.clone()
masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
masked_x[~patch_valid] = -1e4 # torch.finfo(masked_x.dtype).min
masked_max = masked_x.max(dim=1)[0]
# Combine average and max
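Replacing `torch.finfo(dtype).min` with `-1e4` presumably keeps the fill value inside float16 range so the masked max stays safe under autocast. A rough sketch of the masked avg+max pooling this hunk belongs to (the 0.5 weighting and exact reduction order are assumptions):

```python
import torch

def masked_avgmax_pool(x: torch.Tensor, patch_valid: torch.Tensor) -> torch.Tensor:
    # x: [B, N, C] token features, patch_valid: [B, N] bool (True = real patch)
    valid = patch_valid.unsqueeze(-1)                                  # [B, N, 1]
    masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
    # Large negative (but fp16-safe) fill so padded tokens never win the max
    masked_x = x.masked_fill(~valid, -1e4)
    masked_max = masked_x.max(dim=1)[0]
    return 0.5 * (masked_avg + masked_max)
```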
@@ -864,27 +876,23 @@ class VisionTransformerFlex(nn.Module):
Returns:
Model output tensor
"""
# Handle dictionary input from NaFlex collator
if isinstance(x, dict):
assert patch_coord is None
assert patch_valid is None
# Extract the required components from the dictionary
if isinstance(x, torch.Tensor):
patches = x
else:
# Handle dictionary input from NaFlex collator
patch_coord = x['patch_coord']
patch_valid = x['patch_valid']
patches = x['patches']
if False:
# DEBUG, reconstruct patches
for i in range(len(patches)):
patch = patches[i][patch_valid[i]]
h = (patch_coord[i, :, 0].max() + 1).item()
w = (patch_coord[i, :, 1].max() + 1).item()
patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
patch = patch.reshape(3, h*16, w*16)
from torchvision.utils import save_image
save_image(patch, f'patch_{i}.jpg', normalize=True)
else:
patches = x
# DEBUG, reconstruct patches
# for i in range(len(patches)):
# patch = patches[i][patch_valid[i]]
# h = (patch_coord[i, :, 0].max() + 1).item()
# w = (patch_coord[i, :, 1].max() + 1).item()
# patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
# patch = patch.reshape(3, h*16, w*16)
# from torchvision.utils import save_image
# save_image(patch, f'patch_{i}.jpg', normalize=True)
# Create attention mask if patch_type is provided
if patch_valid is not None:
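With the reworked dispatch, the model accepts either a plain image tensor or the dictionary produced by the NaFlex collator. A hypothetical example of the dict form (key names taken from the code above, shapes illustrative):

```python
import torch

B, N, P, C = 2, 256, 16, 3
naflex_batch = {
    'patches': torch.randn(B, N, P * P * C),                # pre-patchified pixels [B, N, P*P*C]
    'patch_coord': torch.zeros(B, N, 2, dtype=torch.long),  # (y, x) grid coords per patch
    'patch_valid': torch.ones(B, N, dtype=torch.bool),      # False for padding patches
}
# Either form is routed through the same forward path:
#   model(naflex_batch)                   # dict from the NaFlex collator
#   model(torch.randn(B, C, 224, 224))    # plain image tensor
```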