Exclude embeds module and mask attn functions from tracing

This commit is contained in:
Ross Wightman 2025-04-09 15:34:15 -07:00
parent 13e0f3a4a3
commit b4bb0f452a

View File

@ -27,6 +27,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
from timm.models._features_fx import register_notrace_function, register_notrace_module
from timm.models._registry import register_model, generate_default_cfgs
from timm.models._manipulate import checkpoint_seq, named_apply
@ -55,6 +56,7 @@ def batch_patchify(
return patches, (nh, nw)
@register_notrace_module
class FlexEmbeds(nn.Module):
""" Na(Flex) Embedding module for Vision Transformers
@ -216,18 +218,18 @@ class FlexEmbeds(nn.Module):
naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
grid_size: Optional[Tuple[int, int]] = None
B = x.shape[0]
if self.is_linear:
# Linear embedding path, works with NaFlex mode or standard 2D mode
B = x.shape[0]
if x.ndim == 3:
# pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
_assert(patch_coord is not None, 'patch_coord must not be None in NaFlex mode')
if patch_coord is not None:
_assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
# Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
# Calculate the appropriate grid size from coords
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
else:
_assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
if self.norm_input is not None:
@ -252,7 +254,7 @@ class FlexEmbeds(nn.Module):
x = self.norm_proj(x)
if self.pos_embed_type == 'learned':
if naflex_grid_sizes:
if naflex_grid_sizes is not None:
self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
else:
self._apply_learned_pos_embed(x, grid_size=grid_size)
@ -336,6 +338,7 @@ class FlexEmbeds(nn.Module):
x.add_(pos_embed)
@register_notrace_function
def create_attention_mask(
patch_valid: Optional[torch.Tensor],
num_prefix_tokens: int = 0,
@ -367,6 +370,8 @@ def create_attention_mask(
return mask_float
@register_notrace_function
def create_attention_mask2(
patch_valid: Optional[torch.Tensor],
num_prefix_tokens: int = 0,
@ -404,6 +409,7 @@ def create_attention_mask2(
return mask_float
@register_notrace_function
def create_pool_mask(
patch_valid: Optional[torch.Tensor],
dtype: torch.dtype = torch.float32,