Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

commit b25ff96768
parent e051dce354

    wip - pre-rebase
@@ -4,10 +4,12 @@ import platform
 import os
 import fnmatch
 
+from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
+
 import timm
 from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
     get_model_default_value
-from timm.models.fx_features import NodePathTracer
+from timm.models.fx_features import _leaf_modules, _autowrap_functions
 
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
@@ -312,12 +314,14 @@ def test_model_forward_fx(model_name, batch_size):
     if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
-    tracer = NodePathTracer()
-    graph = tracer.trace(model)
-    model = torch.fx.GraphModule(model, graph)
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
 
     inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)
+    outputs = model(inputs)[eval_nodes[-1]]
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
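For reference, the torchvision FX feature-extraction API these tests now exercise can be used on its own roughly as follows. This is a minimal sketch, assuming torchvision >= 0.11; `resnet18` is just a stand-in model and is not part of this commit.

    import torch
    from torchvision.models import resnet18
    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

    model = resnet18()
    train_nodes, eval_nodes = get_graph_node_names(model)
    # Return only the final eval-mode node, mirroring the forward test above.
    extractor = create_feature_extractor(model, return_nodes=[eval_nodes[-1]])
    out = extractor(torch.randn(1, 3, 224, 224))[eval_nodes[-1]]
    print(out.shape)  # torch.Size([1, 1000])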
@@ -336,12 +340,30 @@ def test_model_backward_fx(model_name, batch_size):
     model.train()
     num_params = sum([x.numel() for x in model.parameters()])
 
-    tracer = NodePathTracer()
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # If so, we need to return all of them in order to check all grads
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
     graph = tracer.trace(model)
-    model = torch.fx.GraphModule(model, graph)
+    graph_nodes = list(reversed(graph.nodes))
+    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+    graph_node_names = [n.name for n in graph_nodes]
+    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
+
+    model = create_feature_extractor(
+        model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
 
     inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)
+    outputs = tuple(model(inputs).values())
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
     outputs.mean().backward()
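The comment block above describes finding which graph nodes feed the original FX output node, so that every train-mode output keeps receiving gradients. A small illustration of that idea on a toy module, using the public `Node.all_input_nodes` rather than the private `_input_nodes` the test reaches for:

    import torch
    from torch import fx, nn

    class TwoHeads(nn.Module):
        def __init__(self):
            super().__init__()
            self.stem = nn.Linear(8, 8)
            self.head_a = nn.Linear(8, 4)
            self.head_b = nn.Linear(8, 2)

        def forward(self, x):
            x = self.stem(x)
            return self.head_a(x), self.head_b(x)  # multiple outputs in train mode

    m = TwoHeads()
    graph = fx.Tracer().trace(m)
    output_node = list(graph.nodes)[-1]  # the single fx 'output' node
    # The nodes feeding the output are the real model outputs we would want to return:
    print([n.name for n in output_node.all_input_nodes])  # ['head_a', 'head_b']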
@@ -354,9 +376,14 @@ def test_model_backward_fx(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
 
+EXCLUDE_FX_JIT_FILTERS = [
+    'beit_*'  # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+]
+
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize(
-    'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
+    'model_name', list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_fx_torchscript(model_name, batch_size):
     """Symbolically trace each model, script it, and run single forward pass"""
@@ -368,12 +395,18 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    tracer = NodePathTracer()
-    graph = tracer.trace(model)
-    model = torch.fx.GraphModule(model, graph)
+    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
+    if max(input_size) > MAX_FWD_SIZE:
+        pytest.skip("Fixed input size model > limit.")
 
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
     model = torch.jit.script(model)
-    outputs = model(torch.randn((batch_size, *input_size)))
+    outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
@@ -95,11 +95,11 @@ class ClassAttn(nn.Module):
         q = q * self.scale
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
 
-        attn = torch.matmul(q, k.transpose(-2, -1))
+        attn = (q @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C)
+        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
         x_cls = self.proj(x_cls)
         x_cls = self.proj_drop(x_cls)
 
@@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
 
-        attn = torch.matmul(q, k.transpose(-2, -1))
+        attn = (q @ k.transpose(-2, -1))
 
         attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
@@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module):
         attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
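Throughout this commit, `torch.matmul(a, b)` is rewritten as the `a @ b` operator. The two are equivalent for batched tensors; `@` simply records as `operator.matmul` in the traced graph instead of a `torch.matmul` call. A quick check:

    import torch

    q = torch.randn(2, 4, 16, 32)  # B, heads, N, head_dim
    k = torch.randn(2, 4, 16, 32)
    attn_a = torch.matmul(q, k.transpose(-2, -1))
    attn_b = q @ k.transpose(-2, -1)
    print(torch.allclose(attn_a, attn_b))  # True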
@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .registry import register_model
+from .layers.trace_utils import _assert
 
 
 __all__ = [
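The `torch._assert(...)` calls in the model/layer code become `_assert(...)` imported from a new `trace_utils` module whose contents are not shown in this diff. The intent is to route shape checks through a single helper that both FX tracing and TorchScript tolerate. A minimal sketch of what such a wrapper could look like (an assumption, not the file's actual contents):

    try:
        from torch import _assert  # torch._assert is a traceable/scriptable assert wrapper
    except ImportError:
        def _assert(condition: bool, message: str):
            # Pure-Python fallback with the same signature for older torch versions.
            assert condition, message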
@@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
     def forward(self, q, v, size: Tuple[int, int]):
         B, h, N, Ch = q.shape
         H, W = size
-        torch._assert(N == 1 + H * W, '')
+        _assert(N == 1 + H * W, '')
 
         # Convolutional relative position encoding.
         q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
@@ -149,8 +150,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module):
 
         # Factorized attention.
         k_softmax = k.softmax(dim=2)
-        factor_att = torch.matmul(k_softmax.transpose(-1, -2), v)
-        factor_att = torch.matmul(q, factor_att)
+        factor_att = k_softmax.transpose(-1, -2) @ v
+        factor_att = q @ factor_att
 
         # Convolutional relative position encoding.
         crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]
@@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
     def forward(self, x, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
-        torch._assert(N == 1 + H * W, '')
+        _assert(N == 1 + H * W, '')
 
         # Extract CLS token and image tokens.
         cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
@@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
         """ Feature map interpolation. """
         B, N, C = x.shape
         H, W = size
-        torch._assert(N == 1 + H * W, '')
+        _assert(N == 1 + H * W, '')
 
         cls_token = x[:, :1, :]
         img_tokens = x[:, 1:, :]
@@ -57,7 +57,7 @@ default_cfgs = {
 }
 
 
-@register_leaf_module  # FX can't symbolically trace control flow in forward method
+@register_leaf_module  # reason: FX can't symbolically trace control flow in forward method
 class GPSA(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                  locality_strength=1.):
@@ -84,7 +84,7 @@ class GPSA(nn.Module):
             self.rel_indices = self.get_rel_indices(N)
         attn = self.get_attention(x)
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -95,7 +95,7 @@ class GPSA(nn.Module):
         q, k = qk[0], qk[1]
         pos_score = self.rel_indices.expand(B, -1, -1, -1)
         pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
-        patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        patch_score = (q @ k.transpose(-2, -1)) * self.scale
         patch_score = patch_score.softmax(dim=-1)
         pos_score = pos_score.softmax(dim=-1)
 
@@ -180,11 +180,11 @@ class MHSA(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
 Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 
 """
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -31,8 +32,9 @@ from functools import partial
 from typing import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_autowrap_function
 from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import DropPath, to_2tuple, trunc_normal_, _assert
 from .registry import register_model
 from .vision_transformer import Mlp, Block
 
@@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x).flatten(2).transpose(1, 2)
         return x
 
@@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
     return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
 
 
+@register_autowrap_function
+def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
+    """
+    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
+    Args:
+        x (Tensor): input image
+        ss (tuple[int, int]): height and width to scale to
+        crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
+    Returns:
+        Tensor: the "scaled" image batch tensor
+    """
+    H, W = x.shape[-2:]
+    if H != ss[0] or W != ss[1]:
+        if crop_scale and ss[0] <= H and ss[1] <= W:
+            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+        else:
+            x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
+    return x
+
+
 class CrossViT(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """
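As a quick sanity check of the two paths in `scale_image` above (center crop when the target fits inside the input, bicubic interpolation otherwise), both produce the requested spatial size:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 240, 240)
    interp = F.interpolate(x, size=(224, 224), mode='bicubic', align_corners=False)
    cu = cl = int(round((240 - 224) / 2.))  # = 8, matching the crop branch above
    crop = x[:, :, cu:cu + 224, cl:cl + 224]
    print(interp.shape, crop.shape)  # torch.Size([1, 3, 224, 224]) torch.Size([1, 3, 224, 224])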
@@ -342,17 +367,12 @@ class CrossViT(nn.Module):
             range(self.num_branches)])
 
     def forward_features(self, x):
-        B, C, H, W = x.shape
+        B = x.shape[0]
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
             x_ = x
             ss = self.img_size_scaled[i]
-            if H != ss[0] or W != ss[1]:
-                if self.crop_scale and ss[0] <= H and ss[1] <= W:
-                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
-                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
-                else:
-                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = scale_image(x_, ss, self.crop_scale)
             x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
@@ -1,42 +1,31 @@
 """ PyTorch FX Based Feature Extraction Helpers
-An extension/alternative to timm.models.features making use of PyTorch FX. Here, the idea is to:
-  1. Symbolically trace a model producing a graph based intermediate representation (PyTorch FX functionality with
-     some custom tweaks)
-  2. Identify desired feature extraction nodes and reconfigure them as output nodes while deleting all unecessary
-     nodes. (custom - inspired by https://github.com/pytorch/vision/pull/3597)
-  3. Write the resulting graph into a GraphModule (PyTorch FX functionality)
-Copyright 2021 Alexander Soare
+Using https://pytorch.org/vision/stable/feature_extraction.html
 """
-from typing import Callable, Dict, Union, List, Optional
-import math
-from collections import OrderedDict
-from pprint import pprint
-from inspect import ismethod
-import re
-import warnings
-from copy import deepcopy
-from itertools import chain
+from typing import Callable
 
-import torch
 from torch import nn
-from torch import fx
-import torch.nn.functional as F
-from torch.fx.graph_module import _copy_attr
 
 from .features import _get_feature_info
-from .fx_helpers import fx_float_to_int
 
-# Layers we went to treat as leaf modules for FeatureGraphNet
-from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame
-from .layers import GatherExcite, DropPath
+try:
+    from torchvision.models.feature_extraction import create_feature_extractor
+except ImportError:
+    pass
+
+# Layers we went to treat as leaf modules
+from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
 from .layers.non_local_attn import BilinearAttnTransform
 from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 
-# These modules will not be traced through.
+# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
+# BUT modules from timm.models should use the registration mechanism below
 _leaf_modules = {
-    Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, GatherExcite, DropPath,
-    BilinearAttnTransform, MaxPool2dSame, AvgPool2dSame
+    BatchNormAct2d,  # reason: flow control for jit scripting
+    BilinearAttnTransform,  # reason: flow control t <= 1
+    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
+    # Reason: get_same_padding has a max which raises a control flow error
+    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
+    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
+    DropPath,  # reason: TypeError: rand recieved Proxy in `size` argument
 }
 
 try:
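The `_leaf_modules` set above exists because FX symbolic tracing cannot evaluate data-dependent Python control flow on `Proxy` values; registering such modules as leaves keeps their `forward()` out of the traced graph while still recording them as single `call_module` nodes. A small self-contained illustration of the underlying FX behaviour (not timm code):

    import torch
    from torch import fx, nn

    class DataDependent(nn.Module):
        def forward(self, x):
            if x.sum() > 0:  # branch depends on tensor values -> untraceable
                return x + 1
            return x - 1

    try:
        fx.symbolic_trace(DataDependent())
    except Exception as e:
        print('tracing failed:', type(e).__name__)

    class LeafTracer(fx.Tracer):
        # Treat DataDependent as an opaque call_module node instead of tracing into it.
        def is_leaf_module(self, m, qualname):
            return isinstance(m, DataDependent) or super().is_leaf_module(m, qualname)

    root = nn.Sequential(DataDependent())
    gm = fx.GraphModule(root, LeafTracer().trace(root))
    print(gm(torch.randn(2, 3)).shape)  # torch.Size([2, 3])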
@@ -54,425 +43,16 @@ def register_leaf_module(module: nn.Module):
     return module
 
 
-# These functions will not be traced through
-_autowrap_functions=(fx_float_to_int,)
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
 
 
-class TimmTracer(fx.Tracer):
+def register_autowrap_function(func: Callable):
     """
-    Temporary bridge from torch.fx.Tracer to include any general workarounds required to make FX work for us
+    Decorator for functions which ought not to be traced through
     """
-    def __init__(self, autowrap_modules=(math, ), autowrap_functions=(), enable_cpatching=False):
-        super().__init__(autowrap_modules=autowrap_modules, enable_cpatching=enable_cpatching)
+    _autowrap_functions.add(func)
+    return func
-        # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62106
-        self._autowrap_function_ids.update(set([id(f) for f in autowrap_functions]))
-
-    def create_node(self, kind, target, args, kwargs, name=None, type_expr=None):
-        # FIXME: This is a workaround pending on a PyTorch PR https://github.com/pytorch/pytorch/pull/62095
-        if target == F.pad:
-            kwargs['value'] = float(kwargs['value'])
-        return super().create_node(kind, target, args, kwargs, name=name, type_expr=type_expr)
-
-
-class LeafNodeTracer(TimmTracer):
-    """
-    Account for desired leaf nodes according to _leaf_modules and _autowrap functions
-    """
-    def __init__(self):
-        super().__init__(autowrap_functions=_autowrap_functions)
-
-    def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
-        if isinstance(m, tuple(_leaf_modules)):
-            return True
-        return super().is_leaf_module(m, module_qualname)
-
-
-def _is_subseq(x, y):
-    """Check if y is a subseqence of x
-    https://stackoverflow.com/a/24017747/4391249
-    """
-    iter_x = iter(x)
-    return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
-
-
-# Taken from https://github.com/pytorch/examples/blob/master/fx/module_tracer.py with modifications for storing
-# qualified names for all Nodes, not just top-level Modules
-class NodePathTracer(LeafNodeTracer):
-    """
-    NodePathTracer is an FX tracer that, for each operation, also records the
-    qualified name of the Node from which the operation originated. A
-    qualified name here is a `.` seperated path walking the hierarchy from top
-    level module down to leaf operation or leaf module. The name of the top
-    level module is not included as part of the qualified name. For example,
-    if we trace a module who's forward method applies a ReLU module, the
-    qualified name for that node will simply be 'relu'.
-
-    Some notes on the specifics:
-    - Nodes are recorded to `self.node_to_qualname` which is a dictionary
-      mapping a given Node object to its qualified name.
-    - Nodes are recorded in the order which they are executed during
-      tracing.
-    - When a duplicate qualified name is encountered, a suffix of the form
-      _{int} is added. The counter starts from 1.
-    """
-    def __init__(self, *args, **kwargs):
-        super(NodePathTracer, self).__init__(*args, **kwargs)
-        # Track the qualified name of the Node being traced
-        self.current_module_qualname = ''
-        # A map from FX Node to the qualified name
-        self.node_to_qualname = OrderedDict()
-
-    def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
-        """
-        Override of `fx.Tracer.call_module`
-        This override:
-        1) Stores away the qualified name of the caller for restoration later
-        2) Adds the qualified name of the caller to
-           `current_module_qualname` for retrieval by `create_proxy`
-        3) Once a leaf module is reached, calls `create_proxy`
-        4) Restores the caller's qualified name into current_module_qualname
-        """
-        old_qualname = self.current_module_qualname
-        try:
-            module_qualname = self.path_of_module(m)
-            self.current_module_qualname = module_qualname
-            if not self.is_leaf_module(m, module_qualname):
-                out = forward(*args, **kwargs)
-                return out
-            return self.create_proxy('call_module', module_qualname, args, kwargs)
-        finally:
-            self.current_module_qualname = old_qualname
-
-    def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
-                     name=None, type_expr=None) -> fx.proxy.Proxy:
-        """
-        Override of `Tracer.create_proxy`. This override intercepts the recording
-        of every operation and stores away the current traced module's qualified
-        name in `node_to_qualname`
-        """
-        proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
-        self.node_to_qualname[proxy.node] = self._get_node_qualname(
-            self.current_module_qualname, proxy.node)
-        return proxy
-
-    def _get_node_qualname(
-            self, module_qualname: str, node: fx.node.Node) -> str:
-        node_qualname = module_qualname
-        if node.op == 'call_module':
-            # Node terminates in a leaf module so the module_qualname is a
-            # complete description of the node
-            for existing_qualname in reversed(self.node_to_qualname.values()):
-                # Check to see if existing_qualname is of the form
-                # {node_qualname} or {node_qualname}_{int}
-                if re.match(rf'{node_qualname}(_[0-9]+)?$',
-                            existing_qualname) is not None:
-                    postfix = existing_qualname.replace(node_qualname, '')
-                    if len(postfix):
-                        # Existing_qualname is of the form {node_qualname}_{int}
-                        next_index = int(postfix[1:]) + 1
-                    else:
-                        # existing_qualname is of the form {node_qualname}
-                        next_index = 1
-                    node_qualname += f'_{next_index}'
-                    break
-        else:
-            # Node terminates in non- leaf module so the node name needs to be
-            # appended
-            if len(node_qualname) > 0:
-                # Only append '.' if we are deeper than the top level module
-                node_qualname += '.'
-            node_qualname += str(node)
-        return node_qualname
-
-
-def _warn_graph_differences(
-        train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
-    """
-    Utility function for warning the user if there are differences between
-    the train graph and the eval graph.
-    """
-    train_nodes = list(train_tracer.node_to_qualname.values())
-    eval_nodes = list(eval_tracer.node_to_qualname.values())
-
-    if len(train_nodes) == len(eval_nodes) and [
-            t == e for t, e in zip(train_nodes, eval_nodes)]:
-        return
-
-    suggestion_msg = (
-        "When choosing nodes for feature extraction, you may need to specify "
-        "output nodes for train and eval mode separately")
-
-    if _is_subseq(train_nodes, eval_nodes):
-        msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
-               "are a subsequence of those obtained in train mode. ")
-    elif _is_subseq(eval_nodes, train_nodes):
-        msg = ("NOTE: The nodes obtained by tracing the model in train mode "
-               "are a subsequence of those obtained in eval mode. ")
-    else:
-        msg = ("The nodes obtained by tracing the model in train mode "
-               "are different to those obtained in eval mode. ")
-    warnings.warn(msg + suggestion_msg)
-
-
-def print_graph_node_qualified_names(
-        model: nn.Module, tracer_kwargs: Dict = {}):
-    """
-    Dev utility to prints nodes in order of execution. Useful for choosing
-    nodes for a FeatureGraphNet design. There are two reasons that qualified
-    node names can't easily be read directly from the code for a model:
-    1. Not all submodules are traced through. Modules from `torch.nn` all
-       fall within this category.
-    2. Node qualified names that occur more than once in the graph get a
-       `_{counter}` postfix.
-    The model will be traced twice: once in train mode, and once in eval mode.
-    If there are discrepancies between the graphs produced, both sets will
-    be printed and the user will be warned.
-
-    Args:
-        model (nn.Module): model on which we will extract the features
-        tracer_kwargs (Dict): a dictionary of keywork arguments for
-            `NodePathTracer` (which passes them onto it's parent class
-            `torch.fx.Tracer`).
-    """
-    train_tracer = NodePathTracer(**tracer_kwargs)
-    train_tracer.trace(model.train())
-    eval_tracer = NodePathTracer(**tracer_kwargs)
-    eval_tracer.trace(model.eval())
-    train_nodes = list(train_tracer.node_to_qualname.values())
-    eval_nodes = list(eval_tracer.node_to_qualname.values())
-    if len(train_nodes) == len(eval_nodes) and [
-            t == e for t, e in zip(train_nodes, eval_nodes)]:
-        # Nodes are aligned in train vs eval mode
-        pprint(list(train_tracer.node_to_qualname.values()))
-        return
-    print("Nodes from train mode:")
-    pprint(list(train_tracer.node_to_qualname.values()))
-    print()
-    print("Nodes from eval mode:")
-    pprint(list(eval_tracer.node_to_qualname.values()))
-    print()
-    _warn_graph_differences(train_tracer, eval_tracer)
-
-
-class DualGraphModule(fx.GraphModule):
-    """
-    A derivative of `fx.GraphModule`. Differs in the following ways:
-    - Requires a train and eval version of the underlying graph
-    - Copies submodules according to the nodes of both train and eval graphs.
-    - Calling train(mode) switches between train graph and eval graph.
-    """
-    def __init__(self,
-                 root: torch.nn.Module,
-                 train_graph: fx.Graph,
-                 eval_graph: fx.Graph,
-                 class_name: str = 'GraphModule'):
-        """
-        Args:
-            root (torch.nn.Module): module from which the copied module
-                hierarchy is built
-            train_graph (Graph): the graph that should be used in train mode
-            eval_graph (Graph): the graph that should be used in eval mode
-        """
-        super(fx.GraphModule, self).__init__()
-
-        self.__class__.__name__ = class_name
-
-        self.train_graph = train_graph
-        self.eval_graph = eval_graph
-
-        # Copy all get_attr and call_module ops (indicated by BOTH train and
-        # eval graphs)
-        for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
-            if node.op in ['get_attr', 'call_module']:
-                assert isinstance(node.target, str)
-                _copy_attr(root, self, node.target)
-
-        # eval mode by default
-        self.eval()
-        self.graph = eval_graph
-
-        # (borrowed from fx.GraphModule):
-        # Store the Tracer class responsible for creating a Graph separately as part of the
-        # GraphModule state, except when the Tracer is defined in a local namespace.
-        # Locally defined Tracers are not pickleable. This is needed because torch.package will
-        # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
-        # to re-create the Graph during deserialization.
-        # TODO uncomment this when https://github.com/pytorch/pytorch/pull/63121 is available
-        # assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
-        #     "Train mode and eval mode should use the same tracer class"
-        # self._tracer_cls = None
-        # if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
-        #     self._tracer_cls = self.graph._tracer_cls
-
-    def train(self, mode=True):
-        """
-        Swap out the graph depending on the training mode.
-        NOTE this should be safe when calling model.eval() because that just
-        calls this with mode == False.
-        """
-        if mode:
-            self.graph = self.train_graph
-        else:
-            self.graph = self.eval_graph
-        return super().train(mode=mode)
-
-
-def build_feature_graph_net(
-        model: nn.Module,
-        return_nodes: Union[List[str], Dict[str, str]],
-        train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
-        eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
-        tracer_kwargs: Dict = {}) -> fx.GraphModule:
-    """
-    Creates a new graph module that returns intermediate nodes from a given
-    model as dictionary with user specified keys as strings, and the requested
-    outputs as values. This is achieved by re-writing the computation graph of
-    the model via FX to return the desired nodes as outputs. All unused nodes
-    are removed, together with their corresponding parameters.
-
-    A note on node specification: A node qualified name is specified as a `.`
-    seperated path walking the hierarchy from top level module down to leaf
-    operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the
-    `return_nodes` argument should point to either a node's qualified name,
-    or some truncated version of it. For example, one could provide `blocks.5`
-    as a key, and the last node with that prefix will be selected.
-    `print_graph_node_qualified_names` is a useful helper function for getting
-    a list of qualified names of a model.
-
-    An attempt is made to keep all non-parametric properties of the original
-    model, but existing properties of the constructed `GraphModule` are not
-    overwritten.
-
-    Args:
-        model (nn.Module): model on which we will extract the features
-        return_nodes (Union[List[name], Dict[name, new_name]])): either a list
-            or a dict containing the names (or partial names - see note above)
-            of the nodes for which the activations will be returned. If it is
-            a `Dict`, the keys are the qualified node names, and the values
-            are the user-specified keys for the graph module's returned
-            dictionary. If it is a `List`, it is treated as a `Dict` mapping
-            node specification strings directly to output names.
-        tracer_kwargs (Dict): a dictionary of keywork arguments for
-            `NodePathTracer` (which passes them onto it's parent class
-            `torch.fx.Tracer`).
-
-    Examples::
-
-        >>> model = torchvision.models.resnet18()
-        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
-        >>> graph_module = torchvision.models._utils.build_feature_graph_net(m,
-        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
-        >>> out = graph_module(torch.rand(1, 3, 224, 224))
-        >>> print([(k, v.shape) for k, v in out.items()])
-        >>> [('feat1', torch.Size([1, 64, 56, 56])),
-        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
-
-    """
-    is_training = model.training
-
-    if isinstance(return_nodes, list):
-        return_nodes = {n: n for n in return_nodes}
-    return_nodes = {str(k): str(v) for k, v in return_nodes.items()}
-
-    assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
-        ("If any of `train_return_nodes` and `eval_return_nodes` are "
-         "specified, then both should be specified")
-
-    if train_return_nodes is None:
-        train_return_nodes = deepcopy(return_nodes)
-        eval_return_nodes = deepcopy(return_nodes)
-
-    # Repeat the tracing and graph rewriting for train and eval mode
-    tracers = {}
-    graphs = {}
-    return_nodes = {
-        'train': train_return_nodes,
-        'eval': eval_return_nodes
-    }
-    for mode in ['train', 'eval']:
-        if mode == 'train':
-            model.train()
-        elif mode == 'eval':
-            model.eval()
-
-        # Instantiate our NodePathTracer and use that to trace the model
-        tracer = NodePathTracer(**tracer_kwargs)
-        graph = tracer.trace(model)
-
-        name = model.__class__.__name__ if isinstance(
-            model, nn.Module) else model.__name__
-        graph_module = fx.GraphModule(tracer.root, graph, name)
-
-        available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()]
-        # FIXME We don't know if we should expect this to happen
-        assert len(set(available_nodes)) == len(available_nodes), \
-            "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
-        # Check that all outputs in return_nodes are present in the model
-        for query in return_nodes[mode].keys():
-            if not any([m.startswith(query) for m in available_nodes]):
-                raise ValueError(f"return_node: {query} is not present in model")
-
-        # Remove existing output nodes (train mode)
-        orig_output_nodes = []
-        for n in reversed(graph_module.graph.nodes):
-            if n.op == "output":
-                orig_output_nodes.append(n)
-        assert len(orig_output_nodes)
-        for n in orig_output_nodes:
-            graph_module.graph.erase_node(n)
-
-        # Find nodes corresponding to return_nodes and make them into output_nodes
-        nodes = [n for n in graph_module.graph.nodes]
-        output_nodes = OrderedDict()
-        for n in reversed(nodes):
-            if 'tensor_constant' in str(n):
-                # NOTE Without this control flow we would get a None value for
-                # `module_qualname = tracer.node_to_qualname.get(n)`.
-                # On the other hand, we can safely assume that we'll never need to
-                # get this as an interesting intermediate node.
-                continue
-            module_qualname = tracer.node_to_qualname.get(n)
-            for query in return_nodes[mode]:
-                depth = query.count('.')
-                if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
-                    output_nodes[return_nodes[mode][query]] = n
-                    return_nodes[mode].pop(query)
-                    break
-        output_nodes = OrderedDict(reversed(list(output_nodes.items())))
-
-        # And add them in the end of the graph
-        with graph_module.graph.inserting_after(nodes[-1]):
-            graph_module.graph.output(output_nodes)
-
-        # Remove unused modules / parameters
-        graph_module.graph.eliminate_dead_code()
-        graph_module.recompile()
-
-        # Keep track of the tracer and graph so we can choose the main one
-        tracers[mode] = tracer
-        graphs[mode] = graph
-
-    # Warn user if there are any discrepancies between the graphs of the
-    # train and eval modes
-    _warn_graph_differences(tracers['train'], tracers['eval'])
-
-    # Build the final graph module
-    graph_module = DualGraphModule(
-        model, graphs['train'], graphs['eval'], class_name=name)
-
-    # Keep non-parameter model properties for reference
-    for attr_str in model.__dir__():
-        attr = getattr(model, attr_str)
-        if (not attr_str.startswith('_')
-                and attr_str not in graph_module.__dir__()
-                and not ismethod(attr)
-                and not isinstance(attr, (nn.Module, nn.Parameter))):
-            setattr(graph_module, attr_str, attr)
-
-    # Restore original training mode
-    graph_module.train(is_training)
-
-    return graph_module
-
-
 class FeatureGraphNet(nn.Module):
@@ -483,7 +63,10 @@ class FeatureGraphNet(nn.Module):
             assert len(out_map) == len(out_indices)
         return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
                         for i, info in enumerate(self.feature_info) if i in out_indices}
-        self.graph_module = build_feature_graph_net(model, return_nodes)
+        self.graph_module = create_feature_extractor(
+            model, return_nodes,
+            tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
 
     def forward(self, x):
         return list(self.graph_module(x).values())
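FeatureGraphNet now defers to torchvision's `create_feature_extractor`, which takes the same kind of `return_nodes` mapping. Stand-alone, that call pattern looks roughly like this sketch, with a torchvision `resnet18` standing in for a timm model:

    import torch
    from torchvision.models import resnet18
    from torchvision.models.feature_extraction import create_feature_extractor

    extractor = create_feature_extractor(resnet18(), return_nodes={'layer1': 'feat1', 'layer3': 'feat2'})
    out = extractor(torch.randn(1, 3, 224, 224))
    print({k: tuple(v.shape) for k, v in out.items()})
    # {'feat1': (1, 64, 56, 56), 'feat2': (1, 256, 14, 14)}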
@@ -1,7 +0,0 @@
-
-def fx_float_to_int(x: float) -> int:
-    """
-    Symbolic tracing helper to substitute for inbuilt `int`.
-    Hint: Inbuilt `int` can't accept an argument of type `Proxy`
-    """
-    return int(x)
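The deleted `fx_float_to_int` helper existed because the built-in `int()` cannot accept an FX `Proxy`; the replacement approach is to autowrap such helpers so they are recorded as single graph nodes rather than traced into. A minimal illustration of the mechanism in plain torch.fx (not timm code):

    import torch
    from torch import fx, nn

    def to_int(x: float) -> int:
        return int(x)  # fine at runtime; would fail on a Proxy if traced into

    class Scale(nn.Module):
        def forward(self, x):
            half = to_int(x.shape[0] * 0.5)
            return x * half

    m = Scale()
    # autowrap_functions keeps `to_int` as one call_function node instead of tracing inside it.
    tracer = fx.Tracer(autowrap_functions=(to_int,))
    gm = fx.GraphModule(m, tracer.trace(m))
    print(gm(torch.randn(4, 3)).shape)  # torch.Size([4, 3])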
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 
 from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
-from timm.models.fx_helpers import fx_and
+from .trace_utils import _assert
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -37,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
         permute_mask: permute output dim according to this
     """
     B, H, W, dim = q.shape
-    x = torch.matmul(q, rel_k.transpose(-1, -2))
+    x = (q @ rel_k.transpose(-1, -2))
     x = x.reshape(-1, W, 2 * W -1)
 
     # pad to shift from relative to absolute indexing
@@ -134,8 +134,8 @@ class BottleneckAttn(nn.Module):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        torch._assert(H == self.pos_embed.height, '')
-        torch._assert(W == self.pos_embed.width, '')
+        _assert(H == self.pos_embed.height, '')
+        _assert(W == self.pos_embed.width, '')
 
         x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
 
@@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 import torch.nn as nn
 
+from .trace_utils import _assert
+
 
 class EvoNormBatch2d(nn.Module):
     def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
@@ -72,9 +74,9 @@ class EvoNormSample2d(nn.Module):
             nn.init.ones_(self.v)
 
     def forward(self, x):
-        torch._assert(x.dim() == 4, 'expected 4D input')
+        _assert(x.dim() == 4, 'expected 4D input')
         B, C, H, W = x.shape
-        torch._assert(C % self.groups == 0, '')
+        _assert(C % self.groups == 0, '')
         if self.apply_act:
             n = x * (x * self.v).sigmoid()
             x = x.reshape(B, self.groups, -1)
@@ -7,7 +7,6 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-import torch
 from torch import nn as nn
 import torch.nn.functional as F
 
@@ -53,7 +52,7 @@ class GlobalContext(nn.Module):
         if self.conv_attn is not None:
             attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)
             attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
-            context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn)
+            context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
             context = context.view(B, C, 1, 1)
         else:
             context = x.mean(dim=(2, 3), keepdim=True)
@@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Tuple, List
+from typing import List
 
 import torch
 from torch import nn
@@ -24,6 +24,7 @@ import torch.nn.functional as F
 
 from .helpers import make_divisible
 from .weight_init import trunc_normal_
+from .trace_utils import _assert
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     rel_size = rel_k.shape[0]
     win_size = (rel_size + 1) // 2
 
-    x = torch.matmul(q, rel_k.transpose(-1, -2))
+    x = (q @ rel_k.transpose(-1, -2))
     x = x.reshape(-1, W, rel_size)
 
     # pad to shift from relative to absolute indexing
@@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
 
     def forward(self, x):
        B, C, H, W = x.shape
-        torch._assert(H % self.block_size == 0, '')
-        torch._assert(W % self.block_size == 0, '')
+        _assert(H % self.block_size == 0, '')
+        _assert(W % self.block_size == 0, '')
         num_h_blocks = H // self.block_size
         num_w_blocks = W // self.block_size
         num_blocks = num_h_blocks * num_w_blocks
@@ -116,8 +116,8 @@ class LambdaLayer(nn.Module):
         v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
         k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M
 
-        content_lam = torch.matmul(k, v)  # B, K, V
-        content_out = torch.matmul(q, content_lam.unsqueeze(1))  # B, num_heads, M, V
+        content_lam = k @ v  # B, K, V
+        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
 
         if self.pos_emb is None:
             position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 
 from .conv_bn_act import ConvBnAct
 from .helpers import make_divisible
-from timm.models.fx_helpers import fx_and
+from .trace_utils import _assert
 
 
 class NonLocalAttn(nn.Module):
@@ -84,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
 
     def resize_mat(self, x, t: int):
         B, C, block_size, block_size1 = x.shape
-        torch._assert(block_size == block_size1, '')
+        _assert(block_size == block_size1, '')
         if t <= 1:
             return x
         x = x.view(B * C, -1, 1, 1)
@@ -96,8 +96,8 @@ class BilinearAttnTransform(nn.Module):
         return x
 
     def forward(self, x):
-        torch._assert(x.shape[-1] % self.block_size == 0, '')
-        torch._assert(x.shape[-2] % self.block_size == 0, '')
+        _assert(x.shape[-1] % self.block_size == 0, '')
+        _assert(x.shape[-2] % self.block_size == 0, '')
         B, C, H, W = x.shape
         out = self.conv1(x)
         rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
@@ -9,6 +9,7 @@ from torch import nn as nn
 
 from .conv_bn_act import ConvBnAct
 from .helpers import make_divisible
+from .trace_utils import _assert
 
 
 def _kernel_valid(k):
@@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
         self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
 
     def forward(self, x):
-        torch._assert(x.shape[1] == self.num_paths, '')
+        _assert(x.shape[1] == self.num_paths, '')
         x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
@ -1,183 +0,0 @@
-""" Shifted Window Attn
-
-This is a WIP experiment to apply windowed attention from the Swin Transformer
-to a stand-alone module for use as an attn block in conv nets.
-
-Based on original swin window code at https://github.com/microsoft/Swin-Transformer
-Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf
-"""
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from .drop import DropPath
-from .helpers import to_2tuple
-from .weight_init import trunc_normal_
-from timm.models.fx_helpers import fx_float_to_int
-
-
-def window_partition(x, win_size: int):
-    """
-    Args:
-        x: (B, H, W, C)
-        win_size (int): window size
-
-    Returns:
-        windows: (num_windows*B, window_size, window_size, C)
-    """
-    B, H, W, C = x.shape
-    x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
-    return windows
-
-
-def window_reverse(windows, win_size: int, H: int, W: int):
-    """
-    Args:
-        windows: (num_windows*B, window_size, window_size, C)
-        win_size (int): Window size
-        H (int): Height of image
-        W (int): Width of image
-
-    Returns:
-        x: (B, H, W, C)
-    """
-    B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size))
-    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
-    return x
-
-
-class WindowAttention(nn.Module):
-    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
-    It supports both of shifted and non-shifted window.
-
-    Args:
-        dim (int): Number of input channels.
-        win_size (int): The height and width of the window.
-        num_heads (int): Number of attention heads.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
-        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
-    """
-
-    def __init__(
-            self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8,
-            qkv_bias=True, attn_drop=0.):
-
-        super().__init__()
-        self.dim_out = dim_out or dim
-        self.feat_size = to_2tuple(feat_size)
-        self.win_size = win_size
-        self.shift_size = shift_size or win_size // 2
-        if min(self.feat_size) <= win_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.win_size = min(self.feat_size)
-        assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size"
-        self.num_heads = num_heads
-        head_dim = self.dim_out // num_heads
-        self.scale = head_dim ** -0.5
-
-        if self.shift_size > 0:
-            # calculate attention mask for SW-MSA
-            H, W = self.feat_size
-            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
-            h_slices = (
-                slice(0, -self.win_size),
-                slice(-self.win_size, -self.shift_size),
-                slice(-self.shift_size, None))
-            w_slices = (
-                slice(0, -self.win_size),
-                slice(-self.win_size, -self.shift_size),
-                slice(-self.shift_size, None))
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, h, w, :] = cnt
-                    cnt += 1
-            mask_windows = window_partition(img_mask, self.win_size)  # num_win, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-        else:
-            attn_mask = None
-        self.register_buffer("attn_mask", attn_mask)
-
-        # define a parameter table of relative position bias
-        self.relative_position_bias_table = nn.Parameter(
-            # 2 * Wh - 1 * 2 * Ww - 1, nH
-            torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
-        trunc_normal_(self.relative_position_bias_table, std=.02)
-
-        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.win_size)
-        coords_w = torch.arange(self.win_size)
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
-        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
-        relative_coords[:, :, 0] += self.win_size - 1  # shift to start from 0
-        relative_coords[:, :, 1] += self.win_size - 1
-        relative_coords[:, :, 0] *= 2 * self.win_size - 1
-        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.softmax = nn.Softmax(dim=-1)
-        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
-
-    def reset_parameters(self):
-        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
-        trunc_normal_(self.relative_position_bias_table, std=.02)
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        x = x.permute(0, 2, 3, 1)
-
-        # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-        else:
-            shifted_x = x
-
-        # partition windows
-        win_size_sq = self.win_size * self.win_size
-        x_windows = window_partition(shifted_x, self.win_size)  # num_win * B, window_size, window_size, C
-        x_windows = x_windows.view(-1, win_size_sq, C)  # num_win * B, window_size*window_size, C
-        BW, N, _ = x_windows.shape
-
-        qkv = self.qkv(x_windows)
-        qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-        q = q * self.scale
-        attn = torch.matmul(q, k.transpose(-2, -1))
-
-        relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1)
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh * Ww, Wh * Ww
-        attn = attn + relative_position_bias.unsqueeze(0)
-        if self.attn_mask is not None:
-            num_win = self.attn_mask.shape[0]
-            attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0)
-            attn = attn.view(-1, self.num_heads, N, N)
-        attn = self.softmax(attn)
-        attn = self.attn_drop(attn)
-
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out)
-
-        # merge windows
-        x = x.view(-1, self.win_size, self.win_size, self.dim_out)
-        shifted_x = window_reverse(x, self.win_size, H, W)  # B H' W' C
-
-        # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-        else:
-            x = shifted_x
-        x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2)
-        x = self.pool(x)
-        return x
-
-
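The removed module's `window_partition` / `window_reverse` helpers are exact inverses on inputs whose spatial dims divide evenly by the window size. A minimal round-trip check, assuming the two functions as defined in the deleted file above (illustrative only, since the file is being removed):

    import torch

    x = torch.randn(2, 16, 16, 3)           # (B, H, W, C)
    windows = window_partition(x, 4)        # (B * num_windows, 4, 4, C) -> (32, 4, 4, 3)
    y = window_reverse(windows, 4, 16, 16)  # back to (2, 16, 16, 3)
    assert torch.equal(x, y)                # pure reshapes/permutes, so the round trip is exact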
@ -293,10 +293,10 @@ class Attention(nn.Module):
         k = k.permute(0, 2, 1, 3)
         v = v.permute(0, 2, 1, 3)
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
+        attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
         attn = attn.softmax(dim=-1)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x
 
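This hunk and the attention hunks below only swap `torch.matmul(a, b)` back to the `a @ b` operator form; `@` dispatches to the same matmul, so the computation is unchanged. A quick sanity check (illustrative):

    import torch

    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    assert torch.equal(q @ k.transpose(-2, -1), torch.matmul(q, k.transpose(-2, -1)))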
@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module):
         v = v.permute(0, 2, 1, 3)  # BHNC
         q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
+        attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
         attn = attn.softmax(dim=-1)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh)
+        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
         x = self.proj(x)
         return x
 
@ -25,13 +25,13 @@ import torch.nn.functional as F
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_autowrap_function
 from .helpers import build_model_with_cfg, named_apply
-from .fx_helpers import fx_float_to_int
 from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
+from .layers.trace_utils import _assert
 from .layers import create_conv2d, create_pool2d, to_ntuple
 from .registry import register_model
 
 
 _logger = logging.getLogger(__name__)
 
 
@ -85,12 +85,12 @@ class Attention(nn.Module):
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
+        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         # (B, H, T, N, C'), permute -> (B, T, N, C', H)
-        x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
+        x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x  # (B, T, N, C)
@ -130,8 +130,8 @@ class ConvPool(nn.Module):
         """
         x is expected to have shape (B, C, H, W)
         """
-        torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
+        _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
-        torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
+        _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
         x = self.conv(x)
         # Layer norm done over channel dim only
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@ -146,8 +146,8 @@ def blockify(x, block_size: int):
         block_size (int): edge length of a single square block in units of H, W
     """
     B, H, W, C = x.shape
-    torch._assert(H % block_size == 0, '`block_size` must divide input height evenly')
+    _assert(H % block_size == 0, '`block_size` must divide input height evenly')
-    torch._assert(W % block_size == 0, '`block_size` must divide input width evenly')
+    _assert(W % block_size == 0, '`block_size` must divide input width evenly')
     grid_height = H // block_size
     grid_width = W // block_size
     x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
@ -155,6 +155,7 @@ def blockify(x, block_size: int):
     return x  # (B, T, N, C)
 
 
+@register_autowrap_function  # reason: int receives Proxy
 def deblockify(x, block_size: int):
     """blocks to image
     Args:
|
|||||||
block_size (int): edge length of a single square block in units of desired H, W
|
block_size (int): edge length of a single square block in units of desired H, W
|
||||||
"""
|
"""
|
||||||
B, T, _, C = x.shape
|
B, T, _, C = x.shape
|
||||||
grid_size = fx_float_to_int(math.sqrt(T))
|
grid_size = int(math.sqrt(T))
|
||||||
height = width = grid_size * block_size
|
height = width = grid_size * block_size
|
||||||
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
|
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
|
||||||
x = x.transpose(2, 3).reshape(B, height, width, C)
|
x = x.transpose(2, 3).reshape(B, height, width, C)
|
||||||
|
@ -27,7 +27,6 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from timm.models.fx_features import register_leaf_module
 from .registry import register_model
 from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
     get_act_layer, get_act_fn, get_attn, make_divisible
@ -319,7 +318,6 @@ class DownsampleAvg(nn.Module):
         return self.conv(self.pool(x))
 
 
-@register_leaf_module  # FX feature extraction was giving different valued features. Perhaps to do with control flow?
 class NormFreeBlock(nn.Module):
     """Normalization-Free pre-activation block.
     """
@ -21,9 +21,10 @@ import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_autowrap_function
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .fx_helpers import fx_float_to_int
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
+from .layers.trace_utils import _assert
 from .registry import register_model
 from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
 
@ -102,6 +103,7 @@ def window_partition(x, window_size: int):
     return windows
 
 
+@register_autowrap_function  # reason: int argument is a Proxy
 def window_reverse(windows, window_size: int, H: int, W: int):
     """
     Args:
|
|||||||
Returns:
|
Returns:
|
||||||
x: (B, H, W, C)
|
x: (B, H, W, C)
|
||||||
"""
|
"""
|
||||||
B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size))
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||||
return x
|
return x
|
||||||
@ -177,7 +179,7 @@ class WindowAttention(nn.Module):
|
|||||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
attn = torch.matmul(q, k.transpose(-2, -1))
|
attn = (q @ k.transpose(-2, -1))
|
||||||
|
|
||||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||||
@ -194,7 +196,7 @@ class WindowAttention(nn.Module):
 
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@ -272,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
     def forward(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
-        torch._assert(L == H * W, "input feature has wrong size")
+        _assert(L == H * W, "input feature has wrong size")
 
         shortcut = x
         x = self.norm1(x)
@ -331,8 +333,8 @@ class PatchMerging(nn.Module):
         """
         H, W = self.input_resolution
         B, L, C = x.shape
-        torch._assert(L == H * W, "input feature has wrong size")
+        _assert(L == H * W, "input feature has wrong size")
-        torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
+        _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
 
         x = x.view(B, H, W, C)
 
@ -12,9 +12,9 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import build_model_with_cfg
-from timm.models.fx_helpers import fx_and
 from timm.models.layers import Mlp, DropPath, trunc_normal_
 from timm.models.layers.helpers import to_2tuple
+from timm.models.layers.trace_utils import _assert
 from timm.models.registry import register_model
 from timm.models.vision_transformer import resize_pos_embed
 
@ -64,11 +64,11 @@ class Attention(nn.Module):
         q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@ -138,9 +138,9 @@ class PixelEmbed(nn.Module):
 
     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        torch._assert(H == self.img_size[0],
-                      f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        torch._assert(W == self.img_size[1],
-                      f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x)
         x = self.unfold(x)
@ -25,7 +25,7 @@ from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
 from .fx_features import register_leaf_module
 from .registry import register_model
 from .vision_transformer import Attention
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 
 
 def _cfg(url='', **kwargs):
@ -63,7 +63,7 @@ default_cfgs = {
 Size_ = Tuple[int, int]
 
 
-@register_leaf_module  # FX can't symbolically trace control flow in forward method
+@register_leaf_module  # reason: FX can't symbolically trace control flow in forward method
 class LocallyGroupedAttn(nn.Module):
     """ LSA: self attention within a group
     """
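Modules tagged with `register_leaf_module` here have data-dependent Python control flow in `forward`, which FX symbolic tracing cannot follow; keeping them as leaves records a single call_module node instead of tracing through the body. A minimal illustration of the underlying mechanism with a custom `torch.fx.Tracer` (the classes below are made up for the example; timm's registry is assumed to achieve the same effect by handing its leaf-module set to the tracer):

    import torch
    import torch.fx

    class ControlFlowBlock(torch.nn.Module):
        def forward(self, x):
            # data-dependent branch: symbolic tracing would fail here
            return x.relu() if x.sum() > 0 else x

    class LeafAwareTracer(torch.fx.Tracer):
        def is_leaf_module(self, m, qualname):
            # treat ControlFlowBlock as opaque rather than tracing into it
            return isinstance(m, ControlFlowBlock) or super().is_leaf_module(m, qualname)

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block = ControlFlowBlock()

        def forward(self, x):
            return self.block(x) + 1

    root = Net()
    graph = LeafAwareTracer().trace(root)
    gm = torch.fx.GraphModule(root, graph)
    print(gm(torch.randn(4)))  # the block shows up as one call_module node in gm.graph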
@ -100,10 +100,10 @@ class LocallyGroupedAttn(nn.Module):
         qkv = self.qkv(x).reshape(
             B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
         q, k, v = qkv[0], qkv[1], qkv[2]
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
-        attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
+        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
         x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
         if pad_r > 0 or pad_b > 0:
             x = x[:, :H, :W, :].contiguous()
@ -185,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module):
         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
 
@ -13,7 +13,7 @@ from typing import Union, List, Dict, Any, cast
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .fx_features import register_leaf_module
-from .layers import ClassifierHead, ConvBnAct
+from .layers import ClassifierHead
 from .registry import register_model
 
 __all__ = [
@ -53,7 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
 }
 
 
-@register_leaf_module  # FX can't symbolically trace control flow in forward method
+@register_leaf_module  # reason: FX can't symbolically trace control flow in forward method
 class ConvMlp(nn.Module):
 
     def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
|
@ -100,10 +100,10 @@ class Attention(nn.Module):
|
|||||||
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
|
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
|
||||||
q, k, v = x[0], x[1], x[2]
|
q, k, v = x[0], x[1], x[2]
|
||||||
|
|
||||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
attn = attn.softmax(dim=-1)
|
attn = attn.softmax(dim=-1)
|
||||||
attn = self.attn_drop(attn)
|
attn = self.attn_drop(attn)
|
||||||
x = torch.matmul(attn, v)
|
x = attn @ v
|
||||||
|
|
||||||
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
|
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
|
@ -192,11 +192,11 @@ class Attention(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+        attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
|
@ -98,7 +98,7 @@ default_cfgs = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@register_leaf_module # FX can't symbolically trace torch.arange in forward method
|
@register_leaf_module # reason: FX can't symbolically trace torch.arange in forward method
|
||||||
class PositionalEncodingFourier(nn.Module):
|
class PositionalEncodingFourier(nn.Module):
|
||||||
"""
|
"""
|
||||||
Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
|
Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
|
||||||
@ -274,12 +274,12 @@ class XCA(nn.Module):
         # Paper section 3.2 l2-Normalization and temperature scaling
         q = torch.nn.functional.normalize(q, dim=-1)
         k = torch.nn.functional.normalize(k, dim=-1)
-        attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
+        attn = (q @ k.transpose(-2, -1)) * self.temperature
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         # (B, H, C', N), permute -> (B, N, H, C')
-        x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C)
+        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x