mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
rename notrace registration and standardize trace_utils imports
This commit is contained in:
parent
0262a0e8e1
commit
65d827c7a6
@ -19,7 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
from .registry import register_model
|
||||
from .layers.trace_utils import _assert
|
||||
from .layers import _assert
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
@ -30,7 +30,7 @@ from .helpers import build_model_with_cfg
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
|
||||
from .registry import register_model
|
||||
from .vision_transformer_hybrid import HybridEmbed
|
||||
from .fx_features import register_leaf_module
|
||||
from .fx_features import register_notrace_module
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -57,7 +57,7 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
|
||||
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
|
||||
class GPSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
|
||||
locality_strength=1.):
|
||||
|
@ -32,7 +32,7 @@ from functools import partial
|
||||
from typing import List
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_autowrap_function
|
||||
from .fx_features import register_notrace_function
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
|
||||
from .registry import register_model
|
||||
@ -259,7 +259,7 @@ def _compute_num_patches(img_size, patches):
|
||||
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
|
||||
|
||||
|
||||
@register_autowrap_function
|
||||
@register_notrace_function
|
||||
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
|
||||
"""
|
||||
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
|
||||
|
@ -36,7 +36,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def register_leaf_module(module: nn.Module):
|
||||
def register_notrace_module(module: nn.Module):
|
||||
"""
|
||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||
"""
|
||||
@ -48,7 +48,7 @@ def register_leaf_module(module: nn.Module):
|
||||
_autowrap_functions = set()
|
||||
|
||||
|
||||
def register_autowrap_function(func: Callable):
|
||||
def register_notrace_function(func: Callable):
|
||||
"""
|
||||
Decorator for functions which ought not to be traced through
|
||||
"""
|
||||
|
@ -25,10 +25,10 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_autowrap_function
|
||||
from .fx_features import register_notrace_function
|
||||
from .helpers import build_model_with_cfg, named_apply
|
||||
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
|
||||
from .layers.trace_utils import _assert
|
||||
from .layers import _assert
|
||||
from .layers import create_conv2d, create_pool2d, to_ntuple
|
||||
from .registry import register_model
|
||||
|
||||
@ -155,7 +155,7 @@ def blockify(x, block_size: int):
|
||||
return x # (B, T, N, C)
|
||||
|
||||
|
||||
@register_autowrap_function # reason: int receives Proxy
|
||||
@register_notrace_function # reason: int receives Proxy
|
||||
def deblockify(x, block_size: int):
|
||||
"""blocks to image
|
||||
Args:
|
||||
|
@ -26,7 +26,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_leaf_module
|
||||
from .fx_features import register_notrace_module
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
|
||||
@ -319,7 +319,7 @@ class DownsampleAvg(nn.Module):
|
||||
return self.conv(self.pool(x))
|
||||
|
||||
|
||||
@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
|
||||
@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
|
||||
class NormFreeBlock(nn.Module):
|
||||
"""Normalization-Free pre-activation block.
|
||||
"""
|
||||
|
@ -21,10 +21,10 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_autowrap_function
|
||||
from .fx_features import register_notrace_function
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
from .layers.trace_utils import _assert
|
||||
from .layers import _assert
|
||||
from .registry import register_model
|
||||
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
|
||||
|
||||
@ -103,7 +103,7 @@ def window_partition(x, window_size: int):
|
||||
return windows
|
||||
|
||||
|
||||
@register_autowrap_function # reason: int argument is a Proxy
|
||||
@register_notrace_function # reason: int argument is a Proxy
|
||||
def window_reverse(windows, window_size: int, H: int, W: int):
|
||||
"""
|
||||
Args:
|
||||
|
@ -14,7 +14,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.helpers import build_model_with_cfg
|
||||
from timm.models.layers import Mlp, DropPath, trunc_normal_
|
||||
from timm.models.layers.helpers import to_2tuple
|
||||
from timm.models.layers.trace_utils import _assert
|
||||
from timm.models.layers import _assert
|
||||
from timm.models.registry import register_model
|
||||
from timm.models.vision_transformer import resize_pos_embed
|
||||
|
||||
|
@ -22,7 +22,7 @@ from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
from .fx_features import register_leaf_module
|
||||
from .fx_features import register_notrace_module
|
||||
from .registry import register_model
|
||||
from .vision_transformer import Attention
|
||||
from .helpers import build_model_with_cfg
|
||||
@ -63,7 +63,7 @@ default_cfgs = {
|
||||
Size_ = Tuple[int, int]
|
||||
|
||||
|
||||
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
|
||||
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
|
||||
class LocallyGroupedAttn(nn.Module):
|
||||
""" LSA: self attention within a group
|
||||
"""
|
||||
|
@ -12,7 +12,7 @@ from typing import Union, List, Dict, Any, cast
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .fx_features import register_leaf_module
|
||||
from .fx_features import register_notrace_module
|
||||
from .layers import ClassifierHead
|
||||
from .registry import register_model
|
||||
|
||||
@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
|
||||
}
|
||||
|
||||
|
||||
@register_leaf_module # reason: FX can't symbolically trace control flow in forward method
|
||||
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
|
||||
class ConvMlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
|
||||
|
@ -21,7 +21,7 @@ from .vision_transformer import _cfg, Mlp
|
||||
from .registry import register_model
|
||||
from .layers import DropPath, trunc_normal_, to_2tuple
|
||||
from .cait import ClassAttn
|
||||
from .fx_features import register_leaf_module
|
||||
from .fx_features import register_notrace_module
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -98,7 +98,7 @@ default_cfgs = {
|
||||
}
|
||||
|
||||
|
||||
@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method
|
||||
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
|
||||
class PositionalEncodingFourier(nn.Module):
|
||||
"""
|
||||
Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
|
||||
|
Loading…
x
Reference in New Issue
Block a user