mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
wip - attempting to rebase
This commit is contained in:
parent
02c3a75a45
commit
0149ec30d7
@ -24,7 +24,7 @@ import torch.nn.functional as F
|
|||||||
from torch.fx.graph_module import _copy_attr
|
from torch.fx.graph_module import _copy_attr
|
||||||
|
|
||||||
from .features import _get_feature_info
|
from .features import _get_feature_info
|
||||||
from .fx_helpers import fx_and, fx_float_to_int
|
from .fx_helpers import fx_float_to_int
|
||||||
|
|
||||||
# Layers we went to treat as leaf modules for FeatureGraphNet
|
# Layers we went to treat as leaf modules for FeatureGraphNet
|
||||||
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame
|
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame
|
||||||
@ -55,7 +55,7 @@ def register_leaf_module(module: nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
# These functions will not be traced through
|
# These functions will not be traced through
|
||||||
_autowrap_functions=(fx_float_to_int, fx_and)
|
_autowrap_functions=(fx_float_to_int,)
|
||||||
|
|
||||||
|
|
||||||
class TimmTracer(fx.Tracer):
|
class TimmTracer(fx.Tracer):
|
||||||
|
@ -1,14 +1,4 @@
|
|||||||
|
|
||||||
|
|
||||||
def fx_and(a: bool, b: bool) -> bool:
|
|
||||||
"""
|
|
||||||
Symbolic tracing helper to substitute for normal usage of `* and *` within `torch._assert`.
|
|
||||||
Hint: Symbolic tracing does not support control flow but since an `assert` is either a dead-end or not, this hack
|
|
||||||
is okay.
|
|
||||||
"""
|
|
||||||
return (a and b)
|
|
||||||
|
|
||||||
|
|
||||||
def fx_float_to_int(x: float) -> int:
|
def fx_float_to_int(x: float) -> int:
|
||||||
"""
|
"""
|
||||||
Symbolic tracing helper to substitute for inbuilt `int`.
|
Symbolic tracing helper to substitute for inbuilt `int`.
|
||||||
|
@ -22,7 +22,6 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from .helpers import to_2tuple, make_divisible
|
from .helpers import to_2tuple, make_divisible
|
||||||
from .weight_init import trunc_normal_
|
from .weight_init import trunc_normal_
|
||||||
from timm.models.fx_helpers import fx_and
|
|
||||||
|
|
||||||
|
|
||||||
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
||||||
|
@ -24,7 +24,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from .helpers import make_divisible
|
from .helpers import make_divisible
|
||||||
from .weight_init import trunc_normal_
|
from .weight_init import trunc_normal_
|
||||||
from timm.models.fx_helpers import fx_and
|
from timm.models.fx_helpers import
|
||||||
|
|
||||||
|
|
||||||
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
||||||
|
@ -10,7 +10,6 @@ from torch.nn import functional as F
|
|||||||
|
|
||||||
from .conv_bn_act import ConvBnAct
|
from .conv_bn_act import ConvBnAct
|
||||||
from .helpers import make_divisible
|
from .helpers import make_divisible
|
||||||
from timm.models.fx_helpers import fx_and
|
|
||||||
|
|
||||||
|
|
||||||
class NonLocalAttn(nn.Module):
|
class NonLocalAttn(nn.Module):
|
||||||
@ -96,7 +95,8 @@ class BilinearAttnTransform(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '')
|
torch._assert(x.shape[-1] % self.block_size == 0, '')
|
||||||
|
torch._assert(x.shape[-2] % self.block_size == 0, '')
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
out = self.conv1(x)
|
out = self.conv1(x)
|
||||||
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
|
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
|
||||||
|
@ -9,11 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
|
|
||||||
from .helpers import to_2tuple
|
from .helpers import to_2tuple
|
||||||
<<<<<<< HEAD
|
|
||||||
from .trace_utils import _assert
|
from .trace_utils import _assert
|
||||||
=======
|
|
||||||
from timm.models.fx_helpers import fx_and
|
|
||||||
>>>>>>> Make all models FX traceable
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(nn.Module):
|
class PatchEmbed(nn.Module):
|
||||||
|
@ -12,7 +12,6 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.models.helpers import build_model_with_cfg
|
from timm.models.helpers import build_model_with_cfg
|
||||||
from timm.models.fx_helpers import fx_and
|
|
||||||
from timm.models.layers import Mlp, DropPath, trunc_normal_
|
from timm.models.layers import Mlp, DropPath, trunc_normal_
|
||||||
from timm.models.layers.helpers import to_2tuple
|
from timm.models.layers.helpers import to_2tuple
|
||||||
from timm.models.registry import register_model
|
from timm.models.registry import register_model
|
||||||
@ -138,7 +137,9 @@ class PixelEmbed(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, pixel_pos):
|
def forward(self, x, pixel_pos):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]),
|
torch._assert(H == self.img_size[0],
|
||||||
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
||||||
|
torch._assert(W == self.img_size[1],
|
||||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
x = self.unfold(x)
|
x = self.unfold(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user