Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Exclude embeds module and mask attn functions from tracing

Parent: 13e0f3a4a3
Commit: b4bb0f452a
@@ -27,6 +27,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from timm.layers import AttentionPoolLatent, Mlp, to_2tuple, get_act_layer, get_norm_layer, LayerType, _assert
 from timm.models._builder import build_model_with_cfg
 from timm.models._features import feature_take_indices
+from timm.models._features_fx import register_notrace_function, register_notrace_module
 from timm.models._registry import register_model, generate_default_cfgs
 from timm.models._manipulate import checkpoint_seq, named_apply
 
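The only change to the imports: `register_notrace_module` and `register_notrace_function` come from `timm.models._features_fx`, timm's registry for modules and functions that torch.fx symbolic tracing should treat as opaque. Roughly what a registration buys, sketched here against torchvision's `create_feature_extractor` (which timm's FX helpers build on); `Patchify`, `Net`, and the node name `head` are invented for the demo:

    import torch
    import torch.nn as nn
    from torchvision.models.feature_extraction import create_feature_extractor

    class Patchify(nn.Module):
        # Stand-in for a module whose forward has Python-level branching.
        def forward(self, x):
            if x.ndim == 4:   # fine in eager mode, fatal under symbolic tracing
                x = x.flatten(2).transpose(1, 2)
            return x

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = Patchify()
            self.head = nn.Linear(16, 8)

        def forward(self, x):
            return self.head(self.embed(x))

    # Registering a module as "notrace" boils down to handing it to the tracer as a
    # leaf, so the graph records one call_module node instead of recursing into it.
    extractor = create_feature_extractor(
        Net(), return_nodes=['head'], tracer_kwargs={'leaf_modules': [Patchify]})
    print(extractor(torch.randn(2, 16, 4, 4)).keys())   # dict_keys(['head'])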
@@ -55,6 +56,7 @@ def batch_patchify(
     return patches, (nh, nw)
 
 
+@register_notrace_module
 class FlexEmbeds(nn.Module):
     """ Na(Flex) Embedding module for Vision Transformers
 
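Decorating the whole embedding class is the blunt but reliable option: its forward (next hunk) branches on Python conditions derived from the inputs and pulls Python ints out of tensors with `.item()`, neither of which survives symbolic tracing. A minimal reproduction of the failure mode, with `ToyEmbeds` standing in for FlexEmbeds:

    import torch.fx as fx
    import torch.nn as nn

    class ToyEmbeds(nn.Module):
        # Invented stand-in: branches on tensor rank and reads a Python int off a
        # tensor, the same patterns FlexEmbeds.forward uses.
        def forward(self, x):
            if x.ndim == 3:                      # Proxy.__bool__ raises under symbolic_trace
                n = int(x.shape[1])
                return x.reshape(x.shape[0], n, -1)
            return x.flatten(2).transpose(1, 2)

    try:
        fx.symbolic_trace(ToyEmbeds())
    except Exception as e:                       # torch.fx.proxy.TraceError in practice
        print(type(e).__name__, ':', e)
    # Registering the class as a notrace leaf sidesteps this: the tracer records one
    # call_module node and the eager control flow above never sees a Proxy.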
@@ -216,18 +218,18 @@ class FlexEmbeds(nn.Module):
         naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
         grid_size: Optional[Tuple[int, int]] = None
 
+        B = x.shape[0]
         if self.is_linear:
             # Linear embedding path, works with NaFlex mode or standard 2D mode
-            B = x.shape[0]
-            if x.ndim == 3:
-                # pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
-                _assert(patch_coord is not None, 'patch_coord must not be None in NaFlex mode')
-
+            if patch_coord is not None:
+                _assert(x.ndim == 3, 'Expecting patchified input with ndim == 3')
+                # Pre-patchified NaFlex mode, input is expected to be (B, N, P*P*C) where N is num_patches
                 # Calculate the appropriate grid size from coords
                 max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
                 max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
                 naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
             else:
+                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                 x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
 
         if self.norm_input is not None:
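Besides the tracing decorator, the forward body is reshuffled: `B` is hoisted above the branch, the patchified path is now selected by `patch_coord is not None` (the rank check becomes an `_assert`), and the 2D path gains its own rank assert. The per-sample grid recovery, max coordinate + 1 per axis followed by `.item()`, is exactly the Python-int extraction FX cannot trace through, hence the notrace registration. The same computation rerun on made-up coordinates:

    import torch

    # Invented coordinates for 2 samples: (y, x) index of each patch, zero-based.
    # Sample 0 covers a 2x3 grid, sample 1 a 3x2 grid; any padded positions would be
    # handled downstream via patch_valid, not here.
    patch_coord = torch.tensor([
        [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],
        [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]],
    ])  # (B, N, 2)

    max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1   # highest row index + 1, per sample
    max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1   # highest column index + 1, per sample
    naflex_grid_sizes = [(h.item(), w.item()) for h, w in zip(max_y, max_x)]
    print(naflex_grid_sizes)  # [(2, 3), (3, 2)]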
@@ -252,7 +254,7 @@ class FlexEmbeds(nn.Module):
             x = self.norm_proj(x)
 
         if self.pos_embed_type == 'learned':
-            if naflex_grid_sizes:
+            if naflex_grid_sizes is not None:
                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
                 self._apply_learned_pos_embed(x, grid_size=grid_size)
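The guard on the learned position embedding switches from list truthiness to an explicit `None` test, matching how `naflex_grid_sizes` is initialized (`None`, then possibly assigned a list). The two only differ if the list could ever be empty, but the new form states the intent:

    naflex_grid_sizes = []                    # assigned, but empty
    print(bool(naflex_grid_sizes))            # False -> old `if naflex_grid_sizes:` falls to the 2D path
    print(naflex_grid_sizes is not None)      # True  -> new guard stays on the NaFlex path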
@@ -336,6 +338,7 @@ class FlexEmbeds(nn.Module):
         x.add_(pos_embed)
 
 
+@register_notrace_function
 def create_attention_mask(
         patch_valid: Optional[torch.Tensor],
         num_prefix_tokens: int = 0,
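Free functions get the companion decorator: `register_notrace_function` keeps `create_attention_mask` as a single call node in a traced graph rather than tracing through its Optional handling and mask arithmetic, conceptually the same thing `torch.fx.wrap` / autowrap does. A sketch with stock torch.fx and an invented `build_mask` helper (not timm's implementation):

    import torch
    import torch.fx as fx
    import torch.nn as nn

    def build_mask(patch_valid, dtype=torch.float32):
        # Invented helper: additive mask, 0 where a patch is valid, large negative
        # where it is padding.
        if patch_valid is None:                  # Optional handling -> untraceable control flow
            return None
        mask = torch.zeros(patch_valid.shape, dtype=dtype)
        return mask.masked_fill(~patch_valid, torch.finfo(dtype).min)

    fx.wrap('build_mask')  # keep build_mask as one call_function node when traced

    class Toy(nn.Module):
        def forward(self, x, patch_valid):
            return x + build_mask(patch_valid, dtype=x.dtype)

    print(fx.symbolic_trace(Toy()).graph)  # build_mask appears as a single call_function node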
@@ -367,6 +370,8 @@ def create_attention_mask(
 
     return mask_float
 
+
+@register_notrace_function
 def create_attention_mask2(
         patch_valid: Optional[torch.Tensor],
         num_prefix_tokens: int = 0,
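Same treatment for the second mask builder. For orientation, a generic sketch of an additive attention mask built from `patch_valid` plus always-valid prefix tokens; `toy_attention_mask` is invented, and the exact shapes and dtype handling of timm's helpers are not visible in this hunk:

    import torch

    def toy_attention_mask(patch_valid, num_prefix_tokens=1, dtype=torch.float32):
        # Generic sketch, not timm's create_attention_mask/2: prefix tokens (e.g. CLS)
        # are always attendable, padding patches are masked out pairwise.
        B, N = patch_valid.shape
        prefix = patch_valid.new_ones(B, num_prefix_tokens)
        valid = torch.cat([prefix, patch_valid], dim=1)        # (B, num_prefix_tokens + N)
        pair_valid = valid[:, :, None] & valid[:, None, :]     # (B, L, L)
        mask = torch.zeros(pair_valid.shape, dtype=dtype)
        mask = mask.masked_fill(~pair_valid, torch.finfo(dtype).min)
        return mask[:, None]                                   # (B, 1, L, L), broadcast over heads

    pv = torch.tensor([[True, True, False]])    # one sample, last patch is padding
    print(toy_attention_mask(pv).shape)         # torch.Size([1, 1, 4, 4])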
@@ -404,6 +409,7 @@ def create_attention_mask2(
     return mask_float
 
 
+@register_notrace_function
 def create_pool_mask(
        patch_valid: Optional[torch.Tensor],
        dtype: torch.dtype = torch.float32,
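`create_pool_mask` rounds out the set of mask builders excluded from tracing. For completeness, a guess at a pooling-style mask with this signature: only the key axis is masked since the pooling query itself is always valid (`toy_pool_mask` is invented, shapes are assumptions):

    import torch

    def toy_pool_mask(patch_valid, dtype=torch.float32):
        # Sketch only, not timm's create_pool_mask: a pooling query attends over all
        # patch tokens, so a broadcastable (B, 1, 1, N) additive key mask suffices.
        mask = torch.zeros(patch_valid.shape, dtype=dtype)
        mask = mask.masked_fill(~patch_valid, torch.finfo(dtype).min)
        return mask[:, None, None, :]

    print(toy_pool_mask(torch.tensor([[True, True, False]])).shape)  # torch.Size([1, 1, 1, 3])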