Mirror of https://github.com/huggingface/pytorch-image-models.git
Improve several typing issues for flex vit, can (almost) work with jit if we bash h,w key into an int or str

commit ea728f67fa
parent 97341fec51
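The "bash h,w key into an int" idea from the message (it also shows up in the FIXME comments further down) would look roughly like the sketch below. This is illustrative only, not code from the commit: TorchScript is restrictive about dict key types, so a `Tuple[int, int]` grid size can be packed into a single int and unpacked where needed.

```python
from typing import Dict, List, Tuple

def pack_hw(h: int, w: int) -> int:
    # Pack an (h, w) grid size into one int key; assumes h, w < 2**16.
    return (h << 16) | w

def unpack_hw(k: int) -> Tuple[int, int]:
    # Recover (h, w) from the packed key.
    return k >> 16, k & 0xFFFF

# Grouping batch indices by grid size with int keys instead of tuple keys:
size_to_indices: Dict[int, List[int]] = {}
for bi, (h, w) in enumerate([(14, 14), (16, 12), (14, 14)]):
    size_to_indices.setdefault(pack_hw(h, w), []).append(bi)
# -> {917518: [0, 2], 1048588: [1]}
```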
@@ -321,7 +321,7 @@ def resample_patch_embed(
         verbose: bool = False,
 ):
     """ Standalone function (computes matrix on each call). """
-    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_c, in_c, h, w)"
+    assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)"
     assert len(new_size) == 2, "New shape should only be hw (height, width)"
 
     old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
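For context, `resample_patch_embed` resizes patch embedding weights to a new patch size (FlexiViT-style). A usage sketch, assuming the function is importable from `timm.layers` as in recent timm releases:

```python
import torch
from timm.layers import resample_patch_embed

w = torch.randn(384, 3, 16, 16)       # (out_ch, in_ch, h, w), per the fixed assert
w8 = resample_patch_embed(w, [8, 8])  # new_size is (height, width)
print(w8.shape)                       # torch.Size([384, 3, 8, 8])
```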
@@ -42,7 +42,7 @@ def batch_patchify(
         pad: bool = True,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     B, C, H, W = x.shape
-    ph, pw = to_2tuple(patch_size)
+    ph, pw = patch_size
 
     # Ensure the image is divisible by patch size
     if pad and (H % ph != 0 or W % pw != 0):
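A hedged sketch of what `batch_patchify` does with `pad=True` — pad H and W up to multiples of the patch size, then unfold into `[B, N, ph*pw*C]` patches. This mirrors the shape contract visible in the hunk, not the exact timm implementation; note the caller now passes a tuple directly, per this commit:

```python
import torch
import torch.nn.functional as F

def patchify_sketch(x: torch.Tensor, patch_size, pad: bool = True):
    B, C, H, W = x.shape
    ph, pw = patch_size
    if pad and (H % ph != 0 or W % pw != 0):
        x = F.pad(x, (0, -W % pw, 0, -H % ph))  # pad right/bottom to a multiple
        _, _, H, W = x.shape
    gh, gw = H // ph, W // pw
    # [B, C, gh, ph, gw, pw] -> [B, gh, gw, ph, pw, C] -> [B, N, ph*pw*C]
    x = x.view(B, C, gh, ph, gw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, gh * gw, ph * pw * C)
    return x, (gh, gw)

patches, grid = patchify_sketch(torch.randn(2, 3, 60, 60), (16, 16))
print(patches.shape, grid)  # torch.Size([2, 16, 768]) (4, 4)
```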
@@ -202,13 +202,12 @@ class FlexEmbeds(nn.Module):
         else:
             return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
 
-    def forward(self, x, patch_coord=None, patch_valid=None):
+    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None):
         """Forward pass for combined embedding
 
         Args:
             x: Input tensor [B, C, H, W] or pre-patchified [B, N, P*P*C]
             patch_coord: Optional patch coordinates [B, N, 2] for NaFlex
-            patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
 
         Returns:
             Embedded tensor with position encoding and class/register tokens applied
@@ -216,7 +215,7 @@ class FlexEmbeds(nn.Module):
         """
         # Apply patch embedding
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
-        grid_size: Optional[Tuple[int, int]] = None
+        grid_size: Optional[List[int]] = None
 
         B = x.shape[0]
         if self.is_linear:
@@ -227,7 +226,7 @@ class FlexEmbeds(nn.Module):
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
-                naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
+                naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
         else:
             _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
             x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
@@ -257,6 +256,7 @@ class FlexEmbeds(nn.Module):
             if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
+                assert grid_size is not None
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
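The added assert is plausibly there for TorchScript as much as for safety: jit narrows an Optional type after an `assert x is not None` check. A minimal sketch of that refinement (names are illustrative):

```python
import torch
from typing import List, Optional

@torch.jit.script
def grid_numel(grid_size: Optional[List[int]]) -> int:
    # Without the assert, grid_size[0] would fail to compile:
    # Optional[List[int]] is not subscriptable until it is refined.
    assert grid_size is not None
    return grid_size[0] * grid_size[1]
```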
@@ -287,15 +287,19 @@ class FlexEmbeds(nn.Module):
         orig_h, orig_w = self.pos_embed.shape[1:3]
 
         # Determine unique grid sizes
-        size_to_indices = {}
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
         for bi, (h, w) in enumerate(naflex_grid_sizes):
-            if not (h, w) in size_to_indices:
-                size_to_indices[(h, w)] = [bi]
+            #k = h << 16 | w  # FIXME can get jit compat with this
+            k = (h, w)
+            if not k in size_to_indices:
+                size_to_indices[k] = [bi]
             else:
-                size_to_indices[(h, w)].append(bi)
+                size_to_indices[k].append(bi)
 
         # Handle each batch element separately with its own grid size
-        for (h, w), batch_indices in size_to_indices.items():
+        for k, batch_indices in size_to_indices.items():
+            h, w = k
+            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
             # Interpolate only once for this (h, w)
             if (h == orig_h) and (w == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(orig_h * orig_w, -1)
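What happens to each unique (h, w) group after this loop, sketched with made-up sizes; the interpolation mode is an assumption (timm typically uses bicubic for position embedding resizing):

```python
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 16, 16, 384)  # learned [1, H, W, C] grid, as indexed above
h, w = 12, 20                            # one unique NaFlex grid size from the dict

pe = pos_embed.permute(0, 3, 1, 2)       # NHWC -> NCHW for F.interpolate
pe = F.interpolate(pe, size=(h, w), mode='bicubic', align_corners=False)
pos_embed_flat = pe.permute(0, 2, 3, 1).reshape(h * w, -1)  # [h*w, C], as in the fast path
```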
@@ -315,7 +319,7 @@ class FlexEmbeds(nn.Module):
     def _apply_learned_pos_embed(
             self,
             x: torch.Tensor,
-            grid_size: Tuple[int, int],
+            grid_size: List[int],
     ):
         orig_h, orig_w = self.pos_embed.shape[1:3]
         if grid_size[0] != orig_h or grid_size[1] != orig_w:
@@ -340,7 +344,7 @@ class FlexEmbeds(nn.Module):
 
 @register_notrace_function
 def create_attention_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -357,7 +361,7 @@ def create_attention_mask(
         Attention mask of shape [B, seq_len, seq_len] where seq_len = N + num_prefix_tokens,
         or None if patch_type is None
     """
-    patch_valid = patch_valid.bool()
+    patch_valid = patch_valid.to(torch.bool)
     B = patch_valid.shape[0]
 
     if num_prefix_tokens > 0:
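A sketch of the shape contract here. The full construction inside `create_attention_mask` is not shown in this hunk, so the details below are assumptions: prefix tokens are marked valid, pairwise validity gates attention, and invalid pairs get a large negative additive bias.

```python
import torch

patch_valid = torch.tensor([[1, 1, 1, 0],
                            [1, 1, 0, 0]], dtype=torch.bool)  # [B, N]
num_prefix_tokens = 1  # e.g. a class token, always valid

B = patch_valid.shape[0]
if num_prefix_tokens > 0:
    prefix_valid = patch_valid.new_ones(B, num_prefix_tokens)
    patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)  # [B, seq_len]

pair_valid = patch_valid.unsqueeze(1) & patch_valid.unsqueeze(2)  # [B, seq_len, seq_len]
attn_mask = torch.zeros(pair_valid.shape, dtype=torch.float32)
attn_mask.masked_fill_(~pair_valid, torch.finfo(torch.float32).min)
```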
@@ -373,7 +377,7 @@ def create_attention_mask(
 
 @register_notrace_function
 def create_attention_mask2(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid: torch.Tensor,
         num_prefix_tokens: int = 0,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
@@ -411,7 +415,7 @@ def create_attention_mask2(
 
 @register_notrace_function
 def create_pool_mask(
-        patch_valid: Optional[torch.Tensor],
+        patch_valid:torch.Tensor,
        dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     patch_valid = patch_valid.bool()
@@ -773,8 +777,16 @@ class VisionTransformerFlex(nn.Module):
             patch_valid: Optional[torch.Tensor] = None,
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if attn_mask is None and patch_valid is not None:
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens,
+                dtype=x.dtype
+            )
+
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(x, patch_coord=patch_coord)
 
         # Apply transformer blocks with masked attention if mask provided
        if attn_mask is not None:
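With this change a caller no longer needs to build the mask by hand; a hypothetical call (the method and argument names follow the signature above, the variables are placeholders):

```python
# attn_mask is derived from patch_valid when not supplied explicitly
feats = model.forward_features(
    patches,                  # [B, N, P*P*C] pre-patchified input
    patch_coord=patch_coord,  # [B, N, 2]
    patch_valid=patch_valid,  # [B, N] bool
)
```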
@@ -827,7 +839,7 @@ class VisionTransformerFlex(nn.Module):
 
             # For max pooling with mask
             masked_x = x.clone()
-            masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
+            masked_x[~patch_valid] = -1e4  # torch.finfo(masked_x.dtype).min
             masked_max = masked_x.max(dim=1)[0]
 
             # Combine average and max
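The -1e4 fill sidesteps overflow when `finfo(dtype).min` leaks into later arithmetic (averaging in fp16, for instance, can produce -inf/NaN), which is a plausible reason for the change. A sketch of the masked avg+max pooling this line lives in, with the surrounding steps assumed from the comments:

```python
import torch

x = torch.randn(2, 10, 384)            # [B, N, C] token features
patch_valid = torch.rand(2, 10) > 0.3  # [B, N] validity mask

# masked average over valid tokens only
valid = patch_valid.unsqueeze(-1).to(x.dtype)
masked_avg = (x * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)

# masked max, with a large negative but fp16-safe fill
masked_x = x.clone()
masked_x[~patch_valid] = -1e4
masked_max = masked_x.max(dim=1)[0]

pooled = 0.5 * (masked_avg + masked_max)  # "combine average and max"
```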
@@ -864,27 +876,23 @@ class VisionTransformerFlex(nn.Module):
         Returns:
             Model output tensor
         """
-        # Handle dictionary input from NaFlex collator
-        if isinstance(x, dict):
-            assert patch_coord is None
-            assert patch_valid is None
-            # Extract the required components from the dictionary
+        if isinstance(x, torch.Tensor):
+            patches = x
+        else:
+            # Handle dictionary input from NaFlex collator
             patch_coord = x['patch_coord']
             patch_valid = x['patch_valid']
             patches = x['patches']
 
-            if False:
-                # DEBUG, reconstruct patches
-                for i in range(len(patches)):
-                    patch = patches[i][patch_valid[i]]
-                    h = (patch_coord[i, :, 0].max() + 1).item()
-                    w = (patch_coord[i, :, 1].max() + 1).item()
-                    patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
-                    patch = patch.reshape(3, h*16, w*16)
-                    from torchvision.utils import save_image
-                    save_image(patch, f'patch_{i}.jpg', normalize=True)
-        else:
-            patches = x
+        # DEBUG, reconstruct patches
+        # for i in range(len(patches)):
+        #     patch = patches[i][patch_valid[i]]
+        #     h = (patch_coord[i, :, 0].max() + 1).item()
+        #     w = (patch_coord[i, :, 1].max() + 1).item()
+        #     patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3)
+        #     patch = patch.reshape(3, h*16, w*16)
+        #     from torchvision.utils import save_image
+        #     save_image(patch, f'patch_{i}.jpg', normalize=True)
 
         # Create attention mask if patch_type is provided
         if patch_valid is not None:
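A minimal fabricated batch for the dict branch above. The field names come from the diff; the shapes are assumptions consistent with the docstrings:

```python
import torch

batch = {
    'patches': torch.randn(2, 64, 16 * 16 * 3),          # [B, N, P*P*C]
    'patch_coord': torch.randint(0, 8, (2, 64, 2)),      # [B, N, 2] (y, x) grid coords
    'patch_valid': torch.ones(2, 64, dtype=torch.bool),  # 1=patch, 0=padding
}
# model(batch) takes the else-branch; a plain [B, C, H, W] tensor takes the
# isinstance(x, torch.Tensor) branch and is patchified internally.
```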