mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Exclude embeds module and mask attn functions from tracing
This commit is contained in:
parent
13e0f3a4a3
commit
b4bb0f452a
@ -27,6 +27,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
||||
from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
|
||||
from timm.models._builder import build_model_with_cfg
|
||||
from timm.models._features import feature_take_indices
|
||||
from timm.models._features_fx import register_notrace_function, register_notrace_module
|
||||
from timm.models._registry import register_model, generate_default_cfgs
|
||||
from timm.models._manipulate import checkpoint_seq, named_apply
|
||||
|
||||
@ -55,6 +56,7 @@ def batch_patchify(
|
||||
return patches, (nh, nw)
|
||||
|
||||
|
||||
@register_notrace_module
|
||||
class FlexEmbeds(nn.Module):
|
||||
""" Na(Flex) Embedding module for Vision Transformers
|
||||
|
||||
@ -216,18 +218,18 @@ class FlexEmbeds(nn.Module):
|
||||
naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
|
||||
grid_size: Optional[Tuple[int, int]] = None
|
||||
|
||||
B = x.shape[0]
|
||||
if self.is_linear:
|
||||
# Linear embedding path, works with NaFlex mode or standard 2D mode
|
||||
B = x.shape[0]
|
||||
if x.ndim == 3:
|
||||
# pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
|
||||
_assert(patch_coord is not None, 'patch_coord must not be None in NaFlex mode')
|
||||
|
||||
if patch_coord is not None:
|
||||
_assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
|
||||
# Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
|
||||
# Calculate the appropriate grid size from coords
|
||||
max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
|
||||
max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
|
||||
naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
|
||||
else:
|
||||
_assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
|
||||
x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
|
||||
|
||||
if self.norm_input is not None:
|
||||
@ -252,7 +254,7 @@ class FlexEmbeds(nn.Module):
|
||||
x = self.norm_proj(x)
|
||||
|
||||
if self.pos_embed_type == 'learned':
|
||||
if naflex_grid_sizes:
|
||||
if naflex_grid_sizes is not None:
|
||||
self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
|
||||
else:
|
||||
self._apply_learned_pos_embed(x, grid_size=grid_size)
|
||||
@ -336,6 +338,7 @@ class FlexEmbeds(nn.Module):
|
||||
x.add_(pos_embed)
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def create_attention_mask(
|
||||
patch_valid: Optional[torch.Tensor],
|
||||
num_prefix_tokens: int = 0,
|
||||
@ -367,6 +370,8 @@ def create_attention_mask(
|
||||
|
||||
return mask_float
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def create_attention_mask2(
|
||||
patch_valid: Optional[torch.Tensor],
|
||||
num_prefix_tokens: int = 0,
|
||||
@ -404,6 +409,7 @@ def create_attention_mask2(
|
||||
return mask_float
|
||||
|
||||
|
||||
@register_notrace_function
|
||||
def create_pool_mask(
|
||||
patch_valid: Optional[torch.Tensor],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
|
Loading…
x
Reference in New Issue
Block a user