diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index b849f2c5..63db0f70 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -27,6 +27,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
+from timm.models._features_fx import register_notrace_function, register_notrace_module
 from timm.models._registry import register_model, generate_default_cfgs
 from timm.models._manipulate import checkpoint_seq, named_apply
 
@@ -55,6 +56,7 @@ def batch_patchify(
     return patches, (nh, nw)
 
 
+@register_notrace_module
 class FlexEmbeds(nn.Module):
     """ Na(Flex) Embedding module for Vision Transformers
 
@@ -216,18 +218,18 @@ class FlexEmbeds(nn.Module):
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
         grid_size: Optional[Tuple[int, int]] = None
 
+        B = x.shape[0]
         if self.is_linear:
             # Linear embedding path, works with NaFlex mode or standard 2D mode
-            B = x.shape[0]
-            if x.ndim == 3:
-                # pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
-                _assert(patch_coord is not None, 'patch_coord must not be None in NaFlex mode')
-
+            if patch_coord is not None:
+                _assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
+                # Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
                 naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
             else:
+                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
 
             if self.norm_input is not None:
@@ -252,7 +254,7 @@ class FlexEmbeds(nn.Module):
             x = self.norm_proj(x)
 
         if self.pos_embed_type == 'learned':
-            if naflex_grid_sizes:
+            if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
@@ -336,6 +338,7 @@ class FlexEmbeds(nn.Module):
         x.add_(pos_embed)
 
 
+@register_notrace_function
 def create_attention_mask(
     patch_valid: Optional[torch.Tensor],
     num_prefix_tokens: int = 0,
@@ -367,6 +370,8 @@ def create_attention_mask(
 
     return mask_float
 
+
+@register_notrace_function
 def create_attention_mask2(
     patch_valid: Optional[torch.Tensor],
     num_prefix_tokens: int = 0,
@@ -404,6 +409,7 @@ def create_attention_mask2(
 
     return mask_float
 
 
+@register_notrace_function
 def create_pool_mask(
     patch_valid: Optional[torch.Tensor],
     dtype: torch.dtype = torch.float32,
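
For context, the sketch below (not part of the patch) illustrates what the notrace registrations buy us. ToyEmbeds and ToyModel are hypothetical stand-ins for FlexEmbeds and the surrounding model, and it assumes torchvision's create_feature_extractor accepts 'leaf_modules' / 'autowrap_functions' via tracer_kwargs, which is what timm's _features_fx helpers pass through from the register_notrace_module / register_notrace_function registries. The point is that torch.fx symbolic tracing cannot follow the data-dependent control flow in FlexEmbeds.forward (the `patch_coord is not None` branch, the `.item()` calls used to build per-sample grid sizes) or in the create_*_mask helpers, so they need to stay untraced leaves for FX feature extraction to keep working.

# Illustrative only: ToyEmbeds/ToyModel are hypothetical stand-ins, not timm classes.
from typing import Optional

import torch
import torch.nn as nn
from torchvision.models.feature_extraction import create_feature_extractor


class ToyEmbeds(nn.Module):
    """Stand-in for FlexEmbeds: data-dependent Python control flow in forward()."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(4, dim)

    def forward(self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None) -> torch.Tensor:
        if patch_coord is not None:
            # Branching on an optional tensor and calling .item() on traced values is
            # the kind of logic that breaks torch.fx symbolic tracing in the real FlexEmbeds.
            max_y = int(patch_coord[:, :, 0].max().item()) + 1
            max_x = int(patch_coord[:, :, 1].max().item()) + 1
            x = x[:, : max_y * max_x]
        return self.proj(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeds = ToyEmbeds()
        self.head = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor, patch_coord: torch.Tensor) -> torch.Tensor:
        x = self.embeds(x, patch_coord)
        return self.head(x.mean(dim=1))


model = ToyModel()

# Without 'leaf_modules', FX would trace into ToyEmbeds.forward and fail on the
# tensor-dependent branch / .item() calls. Marking the class as a leaf keeps the whole
# module as a single call_module node; register_notrace_module does the same for
# FlexEmbeds (and register_notrace_function feeds 'autowrap_functions' for the
# create_*_mask helpers) in timm's FX feature-extraction path.
extractor = create_feature_extractor(
    model,
    return_nodes={'embeds': 'tokens'},
    tracer_kwargs={'leaf_modules': [ToyEmbeds]},
)

# Patch coordinates for a 4x4 grid, batch of 2: shape (2, 16, 2) with (y, x) pairs.
yy, xx = torch.meshgrid(torch.arange(4), torch.arange(4), indexing='ij')
patch_coord = torch.stack([yy.flatten(), xx.flatten()], dim=-1).unsqueeze(0).expand(2, -1, -1)
out = extractor(torch.randn(2, 16, 4), patch_coord)
print(out['tokens'].shape)  # torch.Size([2, 16, 8])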