Merge branch 'alexander-soare-fx-feature-extract-new'
commit ee40b582bb
@@ -4,9 +4,16 @@ import platform
 import os
 import fnmatch
 
+try:
+    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
 import timm
 from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
     get_model_default_value
+from timm.models.fx_features import _leaf_modules, _autowrap_functions
 
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
@@ -297,3 +304,144 @@ def test_model_forward_features(model_name, batch_size):
         assert e == o.shape[1]
         assert o.shape[0] == batch_size
         assert not torch.isnan(o).any()
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_fx(model_name, batch_size):
+    """
+    Symbolically trace each model and run single forward pass through the resulting GraphModule
+    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
+    """
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
+    graph = tracer.trace(model)
+    graph_nodes = list(reversed(graph.nodes))
+    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+    graph_node_names = [n.name for n in graph_nodes]
+    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
+
+    fx_model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
+    fx_outputs = tuple(fx_model(inputs).values())
+    if isinstance(fx_outputs, tuple):
+        fx_outputs = torch.cat(fx_outputs)
+
+    assert torch.all(fx_outputs == outputs)
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
+@pytest.mark.parametrize('batch_size', [2])
+def test_model_backward_fx(model_name, batch_size):
+    """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
+    if max(input_size) > MAX_BWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    model = create_model(model_name, pretrained=False, num_classes=42)
+    model.train()
+
+    num_params = sum([x.numel() for x in model.parameters()])
+
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
+    graph = tracer.trace(model)
+    graph_nodes = list(reversed(graph.nodes))
+    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+    graph_node_names = [n.name for n in graph_nodes]
+    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
+
+    model = create_feature_extractor(
+        model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = tuple(model(inputs).values())
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
+    outputs.mean().backward()
+    for n, x in model.named_parameters():
+        assert x.grad is not None, f'No gradient for {n}'
+    num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
+
+    assert outputs.shape[-1] == 42
+    assert num_params == num_grad, 'Some parameters are missing gradients'
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+EXCLUDE_FX_JIT_FILTERS = [
+    'deit_*_distilled_patch16_224',
+    'levit*',
+    'pit_*_distilled_224',
+]
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize(
+    'model_name', list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_fx_torchscript(model_name, batch_size):
+    """Symbolically trace each model, script it, and run single forward pass"""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
+    if max(input_size) > MAX_JIT_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    with set_scriptable(True):
+        model = create_model(model_name, pretrained=False)
+    model.eval()
+
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    model = torch.jit.script(model)
+    outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
+
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
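The three tests above all exercise the same torchvision FX helpers with timm's leaf-module and autowrap registries. A condensed sketch of that pattern outside the test harness (model name and input size are illustrative, not taken from this commit):

```python
# Minimal sketch, assuming torch >= 1.10 and torchvision >= 0.11: trace a timm
# model with timm's leaf-module / autowrap registries and grab its final node.
import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

import timm
from timm.models.fx_features import _leaf_modules, _autowrap_functions

model = timm.create_model('resnet18', pretrained=False).eval()
tracer_kwargs = {'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}

train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
fx_model = create_feature_extractor(model, return_nodes=[eval_nodes[-1]], tracer_kwargs=tracer_kwargs)

inputs = torch.randn(1, 3, 224, 224)
fx_out = fx_model(inputs)[eval_nodes[-1]]
assert torch.all(fx_out == model(inputs))  # GraphModule output matches the original module
```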
@@ -86,9 +86,11 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
 
         if window_size:
@@ -127,13 +129,7 @@ class Attention(nn.Module):
 
     def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
-            if torch.jit.is_scripting():
-                # FIXME requires_grad breaks w/ torchscript
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
-            else:
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
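The beit change above swaps the scripting-dependent zeros_like branch for a fixed zero k-bias registered as a non-persistent buffer, so the bias concatenation becomes a single traceable expression. A small illustration of that pattern (a sketch, not code from this commit):

```python
import torch
from torch import nn

class QKVBias(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_bias = nn.Parameter(torch.zeros(dim))
        # persistent=False keeps the zero k-bias out of the state_dict, so
        # existing checkpoints without a k_bias entry still load cleanly
        self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
        self.v_bias = nn.Parameter(torch.zeros(dim))

    def forward(self):
        # one branch-free expression, friendly to both FX tracing and torchscript
        return torch.cat((self.q_bias, self.k_bias, self.v_bias))

m = QKVBias(4)
print(list(m.state_dict().keys()))  # ['q_bias', 'v_bias'] -- no k_bias entry
```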
@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .registry import register_model
+from .layers import _assert
 
 
 __all__ = [
@@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
     def forward(self, q, v, size: Tuple[int, int]):
         B, h, N, Ch = q.shape
         H, W = size
-        assert N == 1 + H * W
+        _assert(N == 1 + H * W, '')
 
         # Convolutional relative position encoding.
         q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
@@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
     def forward(self, x, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
-        assert N == 1 + H * W
+        _assert(N == 1 + H * W, '')
 
         # Extract CLS token and image tokens.
         cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
@@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
         """ Feature map interpolation. """
         B, N, C = x.shape
         H, W = size
-        assert N == 1 + H * W
+        _assert(N == 1 + H * W, '')
 
         cls_token = x[:, :1, :]
         img_tokens = x[:, 1:, :]
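These hunks, and many below, replace bare assert statements with an `_assert` helper (imported from `timm.models.layers` here, from `.trace_utils` inside the layers package later). The helper's definition is not part of the hunks shown in this commit view; a plausible minimal form, assuming it simply defers to torch's symbolically traceable assert when available, would be:

```python
# Hypothetical sketch of an FX-friendly assert helper; the real trace_utils
# implementation is not shown in this diff.
try:
    from torch import _assert  # recent torch exposes a traceable _assert(condition, message)
except ImportError:
    def _assert(condition: bool, message: str):
        assert condition, message
```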
@@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
 from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
 from .registry import register_model
 from .vision_transformer_hybrid import HybridEmbed
+from .fx_features import register_notrace_module
 
 import torch
 import torch.nn as nn
@@ -56,6 +57,7 @@ default_cfgs = {
 }
 
 
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class GPSA(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                  locality_strength=1.):
@@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
 Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 
 """
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -31,8 +32,9 @@ from functools import partial
 from typing import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import DropPath, to_2tuple, trunc_normal_, _assert
 from .registry import register_model
 from .vision_transformer import Mlp, Block
 
@@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x).flatten(2).transpose(1, 2)
         return x
 
@@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
     return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
 
 
+@register_notrace_function
+def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
+    """
+    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
+    Args:
+        x (Tensor): input image
+        ss (tuple[int, int]): height and width to scale to
+        crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
+    Returns:
+        Tensor: the "scaled" image batch tensor
+    """
+    H, W = x.shape[-2:]
+    if H != ss[0] or W != ss[1]:
+        if crop_scale and ss[0] <= H and ss[1] <= W:
+            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+        else:
+            x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
+    return x
+
+
 class CrossViT(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """
@@ -342,17 +367,12 @@
             range(self.num_branches)])
 
     def forward_features(self, x):
-        B, C, H, W = x.shape
+        B = x.shape[0]
        xs = []
         for i, patch_embed in enumerate(self.patch_embed):
             x_ = x
             ss = self.img_size_scaled[i]
-            if H != ss[0] or W != ss[1]:
-                if self.crop_scale and ss[0] <= H and ss[1] <= W:
-                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
-                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
-                else:
-                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = scale_image(x_, ss, self.crop_scale)
             x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
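scale_image above is the recurring pattern in this commit: move data-dependent control flow into a function that the tracer is told to treat as a leaf. A toy illustration of why that works (the names here are made up for the example, not from the diff):

```python
import torch
from torch import fx

def resize_if_needed(x, size: int):
    # data-dependent branch that symbolic tracing cannot follow
    if x.shape[-1] != size:
        x = torch.nn.functional.interpolate(x, size=(size, size), mode='bicubic', align_corners=False)
    return x

class Net(torch.nn.Module):
    def forward(self, x):
        return resize_if_needed(x, 224)

# autowrap_functions makes the tracer record a single call_function node for the
# helper instead of tracing into its branches
tracer = fx.Tracer(autowrap_functions=(resize_if_needed,))
graph = tracer.trace(Net())
print([n.op for n in graph.nodes])  # the helper shows up as one call_function node
```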
@@ -0,0 +1,73 @@
+""" PyTorch FX Based Feature Extraction Helpers
+Using https://pytorch.org/vision/stable/feature_extraction.html
+"""
+from typing import Callable
+from torch import nn
+
+from .features import _get_feature_info
+
+try:
+    from torchvision.models.feature_extraction import create_feature_extractor
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
+# Layers we want to treat as leaf modules
+from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
+from .layers.non_local_attn import BilinearAttnTransform
+from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
+
+# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
+# BUT modules from timm.models should use the registration mechanism below
+_leaf_modules = {
+    BatchNormAct2d,  # reason: flow control for jit scripting
+    BilinearAttnTransform,  # reason: flow control t <= 1
+    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
+    # Reason: get_same_padding has a max which raises a control flow error
+    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
+    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
+    DropPath,  # reason: TypeError: rand received Proxy in `size` argument
+}
+
+try:
+    from .layers import InplaceAbn
+    _leaf_modules.add(InplaceAbn)
+except ImportError:
+    pass
+
+
+def register_notrace_module(module: nn.Module):
+    """
+    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+    """
+    _leaf_modules.add(module)
+    return module
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+
+def register_notrace_function(func: Callable):
+    """
+    Decorator for functions which ought not to be traced through
+    """
+    _autowrap_functions.add(func)
+    return func
+
+
+class FeatureGraphNet(nn.Module):
+    def __init__(self, model, out_indices, out_map=None):
+        super().__init__()
+        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+        self.feature_info = _get_feature_info(model, out_indices)
+        if out_map is not None:
+            assert len(out_map) == len(out_indices)
+        return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
+                        for i, info in enumerate(self.feature_info) if i in out_indices}
+        self.graph_module = create_feature_extractor(
+            model, return_nodes,
+            tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    def forward(self, x):
+        return list(self.graph_module(x).values())
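A short usage sketch for the new FeatureGraphNet wrapper defined above (the model name and out_indices are illustrative; any timm model exposing feature_info metadata should behave the same way):

```python
import torch
import timm
from timm.models.fx_features import FeatureGraphNet

# assumes torch >= 1.10 / torchvision >= 0.11 for create_feature_extractor
backbone = timm.create_model('resnet50', pretrained=False)
features = FeatureGraphNet(backbone, out_indices=(1, 2, 3, 4))

outs = features(torch.randn(1, 3, 224, 224))
for o in outs:
    print(o.shape)  # one feature map per requested stage
```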
@@ -14,6 +14,7 @@ import torch.nn as nn
 
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
+from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
 from .layers import Conv2dSame, Linear
 
@@ -477,6 +478,8 @@ def build_model_with_cfg(
                 feature_cls = feature_cls.lower()
                 if 'hook' in feature_cls:
                     feature_cls = FeatureHookNet
+                elif feature_cls == 'fx':
+                    feature_cls = FeatureGraphNet
                 else:
                     assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 
 from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
+from .trace_utils import _assert
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        assert H == self.pos_embed.height
-        assert W == self.pos_embed.width
+        _assert(H == self.pos_embed.height, '')
+        _assert(W == self.pos_embed.width, '')
 
         x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
 
@@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
         out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
         out = self.pool(out)
         return out
-
-
@@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 import torch.nn as nn
 
+from .trace_utils import _assert
+
 
 class EvoNormBatch2d(nn.Module):
     def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
@@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module):
             nn.init.ones_(self.v)
 
     def forward(self, x):
-        assert x.dim() == 4, 'expected 4D input'
+        _assert(x.dim() == 4, 'expected 4D input')
         B, C, H, W = x.shape
-        assert C % self.groups == 0
+        _assert(C % self.groups == 0, '')
         if self.apply_act:
             n = x * (x * self.v).sigmoid()
             x = x.reshape(B, self.groups, -1)
@@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Tuple, List
+from typing import List
 
 import torch
 from torch import nn
@@ -24,6 +24,7 @@ import torch.nn.functional as F
 
 from .helpers import make_divisible
 from .weight_init import trunc_normal_
+from .trace_utils import _assert
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
 
     def forward(self, x):
        B, C, H, W = x.shape
-        assert H % self.block_size == 0
-        assert W % self.block_size == 0
+        _assert(H % self.block_size == 0, '')
+        _assert(W % self.block_size == 0, '')
         num_h_blocks = H // self.block_size
         num_w_blocks = W // self.block_size
         num_blocks = num_h_blocks * num_w_blocks
@@ -10,6 +10,7 @@ from torch.nn import functional as F
 
 from .conv_bn_act import ConvBnAct
 from .helpers import make_divisible
+from .trace_utils import _assert
 
 
 class NonLocalAttn(nn.Module):
@@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
 
     def resize_mat(self, x, t: int):
         B, C, block_size, block_size1 = x.shape
-        assert block_size == block_size1
+        _assert(block_size == block_size1, '')
         if t <= 1:
             return x
         x = x.view(B * C, -1, 1, 1)
@@ -95,7 +96,8 @@ class BilinearAttnTransform(nn.Module):
         return x
 
     def forward(self, x):
-        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
+        _assert(x.shape[-1] % self.block_size == 0, '')
+        _assert(x.shape[-2] % self.block_size == 0, '')
         B, C, H, W = x.shape
         out = self.conv1(x)
         rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
@@ -9,6 +9,7 @@ from torch import nn as nn
 
 from .conv_bn_act import ConvBnAct
 from .helpers import make_divisible
+from .trace_utils import _assert
 
 
 def _kernel_valid(k):
@@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
         self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
 
     def forward(self, x):
-        assert x.shape[1] == self.num_paths
+        _assert(x.shape[1] == self.num_paths, '')
         x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
@@ -25,8 +25,10 @@ import torch.nn.functional as F
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
+from .layers import _assert
 from .layers import create_conv2d, create_pool2d, to_ntuple
 from .registry import register_model
 
@@ -128,8 +130,8 @@ class ConvPool(nn.Module):
         """
         x is expected to have shape (B, C, H, W)
         """
-        assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims'
-        assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims'
+        _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
+        _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
         x = self.conv(x)
         # Layer norm done over channel dim only
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@@ -144,8 +146,8 @@ def blockify(x, block_size: int):
         block_size (int): edge length of a single square block in units of H, W
     """
     B, H, W, C = x.shape
-    assert H % block_size == 0, '`block_size` must divide input height evenly'
-    assert W % block_size == 0, '`block_size` must divide input width evenly'
+    _assert(H % block_size == 0, '`block_size` must divide input height evenly')
+    _assert(W % block_size == 0, '`block_size` must divide input width evenly')
     grid_height = H // block_size
     grid_width = W // block_size
     x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
@@ -153,6 +155,7 @@ def blockify(x, block_size: int):
     return x  # (B, T, N, C)
 
 
+@register_notrace_function  # reason: int receives Proxy
 def deblockify(x, block_size: int):
     """blocks to image
     Args:
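The "# reason: int receives Proxy" notes above refer to FX replacing tensor shape values with Proxy objects during tracing, so calling int() on a shape expression fails. A toy reproduction of the failure (not from this commit):

```python
import torch
from torch import fx

class Bad(torch.nn.Module):
    def forward(self, x):
        grid = int(x.shape[1] ** 0.5)  # int() receives an fx.Proxy while tracing
        return x.reshape(x.shape[0], grid, grid, -1)

try:
    fx.symbolic_trace(Bad())
except Exception as e:  # exact exception type varies across torch versions
    print('symbolic trace failed:', e)

# Registering such a function with @register_notrace_function (autowrap) keeps it
# out of the traced graph, which is how deblockify is handled above.
```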
@@ -26,6 +26,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_module
 from .helpers import build_model_with_cfg
 from .registry import register_model
 from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
         return self.conv(self.pool(x))
 
 
+@register_notrace_module  # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
 class NormFreeBlock(nn.Module):
     """Normalization-Free pre-activation block.
     """
@@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe
 Copyright 2020 Ross Wightman
 """
 
+import torch
 import torch.nn as nn
 from functools import partial
 from math import ceil
@@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
         if self.use_shortcut:
             if self.drop_path is not None:
                 x = self.drop_path(x)
-            x[:, 0:self.in_channels] += shortcut
+            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
         return x
 
 
@@ -21,11 +21,14 @@ import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
+from .layers import _assert
 from .registry import register_model
 from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
 
 
 _logger = logging.getLogger(__name__)
 
 
@@ -100,6 +103,7 @@ def window_partition(x, window_size: int):
     return windows
 
 
+@register_notrace_function  # reason: int argument is a Proxy
 def window_reverse(windows, window_size: int, H: int, W: int):
     """
     Args:
@@ -270,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
     def forward(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        _assert(L == H * W, "input feature has wrong size")
 
         shortcut = x
         x = self.norm1(x)
@@ -329,8 +333,8 @@ class PatchMerging(nn.Module):
         """
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+        _assert(L == H * W, "input feature has wrong size")
+        _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
 
         x = x.view(B, H, W, C)
 
@@ -9,12 +9,12 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
 import math
 import torch
 import torch.nn as nn
 from functools import partial
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import build_model_with_cfg
 from timm.models.layers import Mlp, DropPath, trunc_normal_
 from timm.models.layers.helpers import to_2tuple
+from timm.models.layers import _assert
 from timm.models.registry import register_model
 from timm.models.vision_transformer import resize_pos_embed
 
@@ -109,7 +109,9 @@ class Block(nn.Module):
         pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
         # outer
         B, N, C = patch_embed.size()
-        patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
+        patch_embed = torch.cat(
+            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
+            dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
         return pixel_embed, patch_embed
@@ -136,8 +138,10 @@ class PixelEmbed(nn.Module):
 
     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x)
         x = self.unfold(x)
         x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
@@ -22,9 +22,10 @@ from functools import partial
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from .fx_features import register_notrace_module
 from .registry import register_model
 from .vision_transformer import Attention
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 
 
 def _cfg(url='', **kwargs):
@@ -62,6 +63,7 @@ default_cfgs = {
 Size_ = Tuple[int, int]
 
 
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class LocallyGroupedAttn(nn.Module):
     """ LSA: self attention within a group
     """
@@ -12,7 +12,8 @@ from typing import Union, List, Dict, Any, cast
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct
+from .fx_features import register_notrace_module
+from .layers import ClassifierHead
 from .registry import register_model
 
 __all__ = [
@@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
 }
 
 
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class ConvMlp(nn.Module):
 
     def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
@@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
 from .registry import register_model
 from .layers import DropPath, trunc_normal_, to_2tuple
 from .cait import ClassAttn
+from .fx_features import register_notrace_module
 
 
 def _cfg(url='', **kwargs):
@@ -97,6 +98,7 @@ default_cfgs = {
 }
 
 
+@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
 class PositionalEncodingFourier(nn.Module):
     """
     Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.