mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
All swin models support spatial output, add output_fmt to v1/v2 and use ClassifierHead.
* update ClassifierHead to allow different input format * add output format support to patch embed * fix some flatten issues for a few conv head models * add Format enum and helpers for tensor format (layout) choices
This commit is contained in:
parent
c30a160d3e
commit
acfd85ad68
@ -27,7 +27,8 @@ except ImportError:
|
||||
|
||||
import timm
|
||||
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
|
||||
from timm.models._features_fx import _leaf_modules, _autowrap_functions
|
||||
from timm.layers import Format, get_spatial_dim, get_channel_dim
|
||||
from timm.models import get_notrace_modules, get_notrace_functions
|
||||
|
||||
if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
# legacy executor is too slow to compile large models for unit tests
|
||||
@ -37,10 +38,10 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
|
||||
# transformer models don't support many of the spatial / feature based model functionalities
|
||||
NON_STD_FILTERS = [
|
||||
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
'eva_*', 'flexivit*'
|
||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||
'eva_*', 'flexivit*',
|
||||
]
|
||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||
|
||||
@ -52,7 +53,7 @@ if 'GITHUB_ACTIONS' in os.environ:
|
||||
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
|
||||
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
|
||||
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
|
||||
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
|
||||
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'eva_giant*']
|
||||
else:
|
||||
EXCLUDE_FILTERS = []
|
||||
NON_STD_EXCLUDE_FILTERS = ['vit_gi*']
|
||||
@ -156,6 +157,10 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
|
||||
pool_size = cfg['pool_size']
|
||||
input_size = model.default_cfg['input_size']
|
||||
output_fmt = getattr(model, 'output_fmt', 'NCHW')
|
||||
spatial_axis = get_spatial_dim(output_fmt)
|
||||
assert len(spatial_axis) == 2 # TODO add 1D sequence support
|
||||
feat_axis = get_channel_dim(output_fmt)
|
||||
|
||||
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
|
||||
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
|
||||
@ -165,13 +170,14 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
|
||||
# test forward_features (always unpooled)
|
||||
outputs = model.forward_features(input_tensor)
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config'
|
||||
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
|
||||
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
|
||||
|
||||
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
||||
model.reset_classifier(0)
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 2
|
||||
assert outputs.shape[-1] == model.num_features
|
||||
assert outputs.shape[1] == model.num_features
|
||||
|
||||
# test model forward without pooling and classifier
|
||||
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
|
||||
@ -179,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
assert len(outputs.shape) == 4
|
||||
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
|
||||
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
|
||||
|
||||
if 'pruned' not in model_name: # FIXME better pruned model handling
|
||||
# test classifier + global pool deletion via __init__
|
||||
@ -187,7 +193,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 4
|
||||
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
|
||||
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
||||
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
|
||||
|
||||
# check classifier name matches default_cfg
|
||||
if cfg.get('num_classes', None):
|
||||
@ -330,176 +336,181 @@ def test_model_forward_features(model_name, batch_size):
|
||||
model = create_model(model_name, pretrained=False, features_only=True)
|
||||
model.eval()
|
||||
expected_channels = model.feature_info.channels()
|
||||
expected_reduction = model.feature_info.reduction()
|
||||
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
|
||||
|
||||
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
|
||||
if max(input_size) > MAX_FFEAT_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
output_fmt = getattr(model, 'output_fmt', 'NCHW')
|
||||
feat_axis = get_channel_dim(output_fmt)
|
||||
spatial_axis = get_spatial_dim(output_fmt)
|
||||
import math
|
||||
|
||||
outputs = model(torch.randn((batch_size, *input_size)))
|
||||
assert len(expected_channels) == len(outputs)
|
||||
for e, o in zip(expected_channels, outputs):
|
||||
assert e == o.shape[1]
|
||||
spatial_size = input_size[-2:]
|
||||
for e, r, o in zip(expected_channels, expected_reduction, outputs):
|
||||
assert e == o.shape[feat_axis]
|
||||
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
|
||||
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
|
||||
assert o.shape[0] == batch_size
|
||||
assert not torch.isnan(o).any()
|
||||
|
||||
|
||||
if not _IS_MAC:
|
||||
# MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
|
||||
def _create_fx_model(model, train=False):
|
||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
||||
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
||||
tracer_kwargs = dict(
|
||||
leaf_modules=get_notrace_modules(),
|
||||
autowrap_functions=get_notrace_functions(),
|
||||
#enable_cpatching=True,
|
||||
param_shapes_constant=True
|
||||
)
|
||||
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
|
||||
|
||||
def _create_fx_model(model, train=False):
|
||||
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
||||
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
||||
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
||||
tracer_kwargs = dict(
|
||||
leaf_modules=list(_leaf_modules),
|
||||
autowrap_functions=list(_autowrap_functions),
|
||||
#enable_cpatching=True,
|
||||
param_shapes_constant=True
|
||||
)
|
||||
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
|
||||
eval_return_nodes = [eval_nodes[-1]]
|
||||
train_return_nodes = [train_nodes[-1]]
|
||||
if train:
|
||||
tracer = NodePathTracer(**tracer_kwargs)
|
||||
graph = tracer.trace(model)
|
||||
graph_nodes = list(reversed(graph.nodes))
|
||||
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
|
||||
graph_node_names = [n.name for n in graph_nodes]
|
||||
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
|
||||
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
|
||||
|
||||
eval_return_nodes = [eval_nodes[-1]]
|
||||
train_return_nodes = [train_nodes[-1]]
|
||||
if train:
|
||||
tracer = NodePathTracer(**tracer_kwargs)
|
||||
graph = tracer.trace(model)
|
||||
graph_nodes = list(reversed(graph.nodes))
|
||||
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
|
||||
graph_node_names = [n.name for n in graph_nodes]
|
||||
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
|
||||
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
|
||||
|
||||
fx_model = create_feature_extractor(
|
||||
model,
|
||||
train_return_nodes=train_return_nodes,
|
||||
eval_return_nodes=eval_return_nodes,
|
||||
tracer_kwargs=tracer_kwargs,
|
||||
)
|
||||
return fx_model
|
||||
fx_model = create_feature_extractor(
|
||||
model,
|
||||
train_return_nodes=train_return_nodes,
|
||||
eval_return_nodes=eval_return_nodes,
|
||||
tracer_kwargs=tracer_kwargs,
|
||||
)
|
||||
return fx_model
|
||||
|
||||
|
||||
EXCLUDE_FX_FILTERS = ['vit_gi*']
|
||||
# not enough memory to run fx on more models than other tests
|
||||
if 'GITHUB_ACTIONS' in os.environ:
|
||||
EXCLUDE_FX_FILTERS += [
|
||||
'beit_large*',
|
||||
'mixer_l*',
|
||||
'*nfnet_f2*',
|
||||
'*resnext101_32x32d',
|
||||
'resnetv2_152x2*',
|
||||
'resmlp_big*',
|
||||
'resnetrs270',
|
||||
'swin_large*',
|
||||
'vgg*',
|
||||
'vit_large*',
|
||||
'vit_base_patch8*',
|
||||
'xcit_large*',
|
||||
]
|
||||
EXCLUDE_FX_FILTERS = ['vit_gi*']
|
||||
# not enough memory to run fx on more models than other tests
|
||||
if 'GITHUB_ACTIONS' in os.environ:
|
||||
EXCLUDE_FX_FILTERS += [
|
||||
'beit_large*',
|
||||
'mixer_l*',
|
||||
'*nfnet_f2*',
|
||||
'*resnext101_32x32d',
|
||||
'resnetv2_152x2*',
|
||||
'resmlp_big*',
|
||||
'resnetrs270',
|
||||
'swin_large*',
|
||||
'vgg*',
|
||||
'vit_large*',
|
||||
'vit_base_patch8*',
|
||||
'xcit_large*',
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.fxforward
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_fx(model_name, batch_size):
|
||||
"""
|
||||
Symbolically trace each model and run single forward pass through the resulting GraphModule
|
||||
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
|
||||
"""
|
||||
if not has_fx_feature_extraction:
|
||||
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
|
||||
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
|
||||
if max(input_size) > MAX_FWD_FX_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
with torch.no_grad():
|
||||
inputs = torch.randn((batch_size, *input_size))
|
||||
outputs = model(inputs)
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = torch.cat(outputs)
|
||||
|
||||
model = _create_fx_model(model)
|
||||
fx_outputs = tuple(model(inputs).values())
|
||||
if isinstance(fx_outputs, tuple):
|
||||
fx_outputs = torch.cat(fx_outputs)
|
||||
|
||||
assert torch.all(fx_outputs == outputs)
|
||||
assert outputs.shape[0] == batch_size
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
@pytest.mark.fxbackward
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(
|
||||
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [2])
|
||||
def test_model_backward_fx(model_name, batch_size):
|
||||
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
||||
if not has_fx_feature_extraction:
|
||||
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||
|
||||
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
|
||||
if max(input_size) > MAX_BWD_FX_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
model = create_model(model_name, pretrained=False, num_classes=42)
|
||||
model.train()
|
||||
num_params = sum([x.numel() for x in model.parameters()])
|
||||
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
|
||||
pytest.skip("Skipping FX backward test on model with more than 100M params.")
|
||||
|
||||
model = _create_fx_model(model, train=True)
|
||||
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = torch.cat(outputs)
|
||||
outputs.mean().backward()
|
||||
for n, x in model.named_parameters():
|
||||
assert x.grad is not None, f'No gradient for {n}'
|
||||
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
|
||||
|
||||
assert outputs.shape[-1] == 42
|
||||
assert num_params == num_grad, 'Some parameters are missing gradients'
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
if 'GITHUB_ACTIONS' not in os.environ:
|
||||
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
|
||||
|
||||
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
||||
EXCLUDE_FX_JIT_FILTERS = [
|
||||
'deit_*_distilled_patch16_224',
|
||||
'levit*',
|
||||
'pit_*_distilled_224',
|
||||
] + EXCLUDE_FX_FILTERS
|
||||
|
||||
|
||||
@pytest.mark.fxforward
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
|
||||
@pytest.mark.parametrize(
|
||||
'model_name', list_models(
|
||||
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_fx(model_name, batch_size):
|
||||
"""
|
||||
Symbolically trace each model and run single forward pass through the resulting GraphModule
|
||||
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
|
||||
"""
|
||||
def test_model_forward_fx_torchscript(model_name, batch_size):
|
||||
"""Symbolically trace each model, script it, and run single forward pass"""
|
||||
if not has_fx_feature_extraction:
|
||||
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||
|
||||
model = create_model(model_name, pretrained=False)
|
||||
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
||||
if max(input_size) > MAX_JIT_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
with set_scriptable(True):
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
|
||||
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
|
||||
if max(input_size) > MAX_FWD_FX_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
model = torch.jit.script(_create_fx_model(model))
|
||||
with torch.no_grad():
|
||||
inputs = torch.randn((batch_size, *input_size))
|
||||
outputs = model(inputs)
|
||||
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = torch.cat(outputs)
|
||||
|
||||
model = _create_fx_model(model)
|
||||
fx_outputs = tuple(model(inputs).values())
|
||||
if isinstance(fx_outputs, tuple):
|
||||
fx_outputs = torch.cat(fx_outputs)
|
||||
|
||||
assert torch.all(fx_outputs == outputs)
|
||||
assert outputs.shape[0] == batch_size
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
@pytest.mark.fxbackward
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(
|
||||
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [2])
|
||||
def test_model_backward_fx(model_name, batch_size):
|
||||
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
||||
if not has_fx_feature_extraction:
|
||||
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||
|
||||
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
|
||||
if max(input_size) > MAX_BWD_FX_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
model = create_model(model_name, pretrained=False, num_classes=42)
|
||||
model.train()
|
||||
num_params = sum([x.numel() for x in model.parameters()])
|
||||
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
|
||||
pytest.skip("Skipping FX backward test on model with more than 100M params.")
|
||||
|
||||
model = _create_fx_model(model, train=True)
|
||||
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = torch.cat(outputs)
|
||||
outputs.mean().backward()
|
||||
for n, x in model.named_parameters():
|
||||
assert x.grad is not None, f'No gradient for {n}'
|
||||
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
|
||||
|
||||
assert outputs.shape[-1] == 42
|
||||
assert num_params == num_grad, 'Some parameters are missing gradients'
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
if 'GITHUB_ACTIONS' not in os.environ:
|
||||
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
|
||||
|
||||
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
||||
EXCLUDE_FX_JIT_FILTERS = [
|
||||
'deit_*_distilled_patch16_224',
|
||||
'levit*',
|
||||
'pit_*_distilled_224',
|
||||
] + EXCLUDE_FX_FILTERS
|
||||
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize(
|
||||
'model_name', list_models(
|
||||
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward_fx_torchscript(model_name, batch_size):
|
||||
"""Symbolically trace each model, script it, and run single forward pass"""
|
||||
if not has_fx_feature_extraction:
|
||||
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
||||
|
||||
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
||||
if max(input_size) > MAX_JIT_SIZE:
|
||||
pytest.skip("Fixed input size model > limit.")
|
||||
|
||||
with set_scriptable(True):
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
|
||||
model = torch.jit.script(_create_fx_model(model))
|
||||
with torch.no_grad():
|
||||
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
||||
if isinstance(outputs, tuple):
|
||||
outputs = torch.cat(outputs)
|
||||
|
||||
assert outputs.shape[0] == batch_size
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
@ -20,6 +20,7 @@ from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
|
||||
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
|
||||
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
|
||||
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
|
||||
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
|
||||
from .gather_excite import GatherExcite
|
||||
from .global_context import GlobalContext
|
||||
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
|
||||
|
@ -9,31 +9,37 @@ Both a functional and a nn.Module version of the pooling is provided.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .format import get_spatial_dim, get_channel_dim
|
||||
|
||||
_int_tuple_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
def adaptive_pool_feat_mult(pool_type='avg'):
|
||||
if pool_type == 'catavgmax':
|
||||
if pool_type.endswith('catavgmax'):
|
||||
return 2
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def adaptive_avgmax_pool2d(x, output_size=1):
|
||||
def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
|
||||
x_avg = F.adaptive_avg_pool2d(x, output_size)
|
||||
x_max = F.adaptive_max_pool2d(x, output_size)
|
||||
return 0.5 * (x_avg + x_max)
|
||||
|
||||
|
||||
def adaptive_catavgmax_pool2d(x, output_size=1):
|
||||
def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
|
||||
x_avg = F.adaptive_avg_pool2d(x, output_size)
|
||||
x_max = F.adaptive_max_pool2d(x, output_size)
|
||||
return torch.cat((x_avg, x_max), 1)
|
||||
|
||||
|
||||
def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
|
||||
def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
|
||||
"""Selectable global pooling function with dynamic input kernel size
|
||||
"""
|
||||
if pool_type == 'avg':
|
||||
@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
|
||||
return x
|
||||
|
||||
|
||||
class FastAdaptiveAvgPool2d(nn.Module):
|
||||
def __init__(self, flatten=False):
|
||||
super(FastAdaptiveAvgPool2d, self).__init__()
|
||||
class FastAdaptiveAvgPool(nn.Module):
|
||||
def __init__(self, flatten: bool = False, input_fmt: F = 'NCHW'):
|
||||
super(FastAdaptiveAvgPool, self).__init__()
|
||||
self.flatten = flatten
|
||||
self.dim = get_spatial_dim(input_fmt)
|
||||
|
||||
def forward(self, x):
|
||||
return x.mean((2, 3), keepdim=not self.flatten)
|
||||
return x.mean(self.dim, keepdim=not self.flatten)
|
||||
|
||||
|
||||
class FastAdaptiveMaxPool(nn.Module):
|
||||
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
|
||||
super(FastAdaptiveMaxPool, self).__init__()
|
||||
self.flatten = flatten
|
||||
self.dim = get_spatial_dim(input_fmt)
|
||||
|
||||
def forward(self, x):
|
||||
return x.amax(self.dim, keepdim=not self.flatten)
|
||||
|
||||
|
||||
class FastAdaptiveAvgMaxPool(nn.Module):
|
||||
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
|
||||
super(FastAdaptiveAvgMaxPool, self).__init__()
|
||||
self.flatten = flatten
|
||||
self.dim = get_spatial_dim(input_fmt)
|
||||
|
||||
def forward(self, x):
|
||||
x_avg = x.mean(self.dim, keepdim=not self.flatten)
|
||||
x_max = x.amax(self.dim, keepdim=not self.flatten)
|
||||
return 0.5 * x_avg + 0.5 * x_max
|
||||
|
||||
|
||||
class FastAdaptiveCatAvgMaxPool(nn.Module):
|
||||
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
|
||||
super(FastAdaptiveCatAvgMaxPool, self).__init__()
|
||||
self.flatten = flatten
|
||||
self.dim_reduce = get_spatial_dim(input_fmt)
|
||||
if flatten:
|
||||
self.dim_cat = 1
|
||||
else:
|
||||
self.dim_cat = get_channel_dim(input_fmt)
|
||||
|
||||
def forward(self, x):
|
||||
x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten)
|
||||
x_max = x.amax(self.dim_reduce, keepdim=not self.flatten)
|
||||
return torch.cat((x_avg, x_max), self.dim_cat)
|
||||
|
||||
|
||||
class AdaptiveAvgMaxPool2d(nn.Module):
|
||||
def __init__(self, output_size=1):
|
||||
def __init__(self, output_size: _int_tuple_2_t = 1):
|
||||
super(AdaptiveAvgMaxPool2d, self).__init__()
|
||||
self.output_size = output_size
|
||||
|
||||
@ -68,7 +113,7 @@ class AdaptiveAvgMaxPool2d(nn.Module):
|
||||
|
||||
|
||||
class AdaptiveCatAvgMaxPool2d(nn.Module):
|
||||
def __init__(self, output_size=1):
|
||||
def __init__(self, output_size: _int_tuple_2_t = 1):
|
||||
super(AdaptiveCatAvgMaxPool2d, self).__init__()
|
||||
self.output_size = output_size
|
||||
|
||||
@ -79,26 +124,41 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
|
||||
class SelectAdaptivePool2d(nn.Module):
|
||||
"""Selectable global pooling layer with dynamic input kernel size
|
||||
"""
|
||||
def __init__(self, output_size=1, pool_type='fast', flatten=False):
|
||||
def __init__(
|
||||
self,
|
||||
output_size: _int_tuple_2_t = 1,
|
||||
pool_type: str = 'fast',
|
||||
flatten: bool = False,
|
||||
input_fmt: str = 'NCHW',
|
||||
):
|
||||
super(SelectAdaptivePool2d, self).__init__()
|
||||
assert input_fmt in ('NCHW', 'NHWC')
|
||||
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
|
||||
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
|
||||
if pool_type == '':
|
||||
if not pool_type:
|
||||
self.pool = nn.Identity() # pass through
|
||||
elif pool_type == 'fast':
|
||||
assert output_size == 1
|
||||
self.pool = FastAdaptiveAvgPool2d(flatten)
|
||||
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
|
||||
elif pool_type.startswith('fast') or input_fmt != 'NCHW':
|
||||
assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
|
||||
if pool_type.endswith('avgmax'):
|
||||
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
|
||||
elif pool_type.endswith('catavgmax'):
|
||||
self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
|
||||
elif pool_type.endswith('max'):
|
||||
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
|
||||
else:
|
||||
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
|
||||
self.flatten = nn.Identity()
|
||||
elif pool_type == 'avg':
|
||||
self.pool = nn.AdaptiveAvgPool2d(output_size)
|
||||
elif pool_type == 'avgmax':
|
||||
self.pool = AdaptiveAvgMaxPool2d(output_size)
|
||||
elif pool_type == 'catavgmax':
|
||||
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
|
||||
elif pool_type == 'max':
|
||||
self.pool = nn.AdaptiveMaxPool2d(output_size)
|
||||
else:
|
||||
assert False, 'Invalid pool type: %s' % pool_type
|
||||
assert input_fmt == 'NCHW'
|
||||
if pool_type == 'avgmax':
|
||||
self.pool = AdaptiveAvgMaxPool2d(output_size)
|
||||
elif pool_type == 'catavgmax':
|
||||
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
|
||||
elif pool_type == 'max':
|
||||
self.pool = nn.AdaptiveMaxPool2d(output_size)
|
||||
else:
|
||||
self.pool = nn.AdaptiveAvgPool2d(output_size)
|
||||
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
|
||||
|
||||
def is_identity(self):
|
||||
return not self.pool_type
|
||||
|
@ -15,13 +15,23 @@ from .create_act import get_act_layer
|
||||
from .create_norm import get_norm_layer
|
||||
|
||||
|
||||
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
def _create_pool(
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
pool_type: str = 'avg',
|
||||
use_conv: bool = False,
|
||||
input_fmt: Optional[str] = None,
|
||||
):
|
||||
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
|
||||
if not pool_type:
|
||||
assert num_classes == 0 or use_conv,\
|
||||
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
|
||||
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
|
||||
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
|
||||
global_pool = SelectAdaptivePool2d(
|
||||
pool_type=pool_type,
|
||||
flatten=flatten_in_pool,
|
||||
input_fmt=input_fmt,
|
||||
)
|
||||
num_pooled_features = num_features * global_pool.feat_mult()
|
||||
return global_pool, num_pooled_features
|
||||
|
||||
@ -36,9 +46,25 @@ def _create_fc(num_features, num_classes, use_conv=False):
|
||||
return fc
|
||||
|
||||
|
||||
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
|
||||
fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
||||
def create_classifier(
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
pool_type: str = 'avg',
|
||||
use_conv: bool = False,
|
||||
input_fmt: str = 'NCHW',
|
||||
):
|
||||
global_pool, num_pooled_features = _create_pool(
|
||||
num_features,
|
||||
num_classes,
|
||||
pool_type,
|
||||
use_conv=use_conv,
|
||||
input_fmt=input_fmt,
|
||||
)
|
||||
fc = _create_fc(
|
||||
num_pooled_features,
|
||||
num_classes,
|
||||
use_conv=use_conv,
|
||||
)
|
||||
return global_pool, fc
|
||||
|
||||
|
||||
@ -52,6 +78,7 @@ class ClassifierHead(nn.Module):
|
||||
pool_type: str = 'avg',
|
||||
drop_rate: float = 0.,
|
||||
use_conv: bool = False,
|
||||
input_fmt: str = 'NCHW',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -64,28 +91,43 @@ class ClassifierHead(nn.Module):
|
||||
self.drop_rate = drop_rate
|
||||
self.in_features = in_features
|
||||
self.use_conv = use_conv
|
||||
self.input_fmt = input_fmt
|
||||
|
||||
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
|
||||
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
||||
self.global_pool, self.fc = create_classifier(
|
||||
in_features,
|
||||
num_classes,
|
||||
pool_type,
|
||||
use_conv=use_conv,
|
||||
input_fmt=input_fmt,
|
||||
)
|
||||
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
|
||||
|
||||
def reset(self, num_classes, global_pool=None):
|
||||
if global_pool is not None:
|
||||
if global_pool != self.global_pool.pool_type:
|
||||
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
|
||||
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
|
||||
num_pooled_features = self.in_features * self.global_pool.feat_mult()
|
||||
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
|
||||
def reset(self, num_classes, pool_type=None):
|
||||
if pool_type is not None and pool_type != self.global_pool.pool_type:
|
||||
self.global_pool, self.fc = create_classifier(
|
||||
self.in_features,
|
||||
num_classes,
|
||||
pool_type=pool_type,
|
||||
use_conv=self.use_conv,
|
||||
input_fmt=self.input_fmt,
|
||||
)
|
||||
self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
|
||||
else:
|
||||
num_pooled_features = self.in_features * self.global_pool.feat_mult()
|
||||
self.fc = _create_fc(
|
||||
num_pooled_features,
|
||||
num_classes,
|
||||
use_conv=self.use_conv,
|
||||
)
|
||||
|
||||
def forward(self, x, pre_logits: bool = False):
|
||||
x = self.global_pool(x)
|
||||
if self.drop_rate:
|
||||
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.fc(x)
|
||||
return self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return self.flatten(x)
|
||||
|
||||
|
||||
class NormMlpClassifierHead(nn.Module):
|
||||
|
58
timm/layers/format.py
Normal file
58
timm/layers/format.py
Normal file
@ -0,0 +1,58 @@
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Format(str, Enum):
|
||||
NCHW = 'NCHW'
|
||||
NHWC = 'NHWC'
|
||||
NCL = 'NCL'
|
||||
NLC = 'NLC'
|
||||
|
||||
|
||||
FormatT = Union[str, Format]
|
||||
|
||||
|
||||
def get_spatial_dim(fmt: FormatT):
|
||||
fmt = Format(fmt)
|
||||
if fmt is Format.NLC:
|
||||
dim = (1,)
|
||||
elif fmt is Format.NCL:
|
||||
dim = (2,)
|
||||
elif fmt is Format.NHWC:
|
||||
dim = (1, 2)
|
||||
else:
|
||||
dim = (2, 3)
|
||||
return dim
|
||||
|
||||
|
||||
def get_channel_dim(fmt: FormatT):
|
||||
fmt = Format(fmt)
|
||||
if fmt is Format.NHWC:
|
||||
dim = 3
|
||||
elif fmt is Format.NLC:
|
||||
dim = 2
|
||||
else:
|
||||
dim = 1
|
||||
return dim
|
||||
|
||||
|
||||
def nchw_to(x: torch.Tensor, fmt: Format):
|
||||
if fmt == Format.NHWC:
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
elif fmt == Format.NLC:
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
elif fmt == Format.NCL:
|
||||
x = x.flatten(2)
|
||||
return x
|
||||
|
||||
|
||||
def nhwc_to(x: torch.Tensor, fmt: Format):
|
||||
if fmt == Format.NCHW:
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
elif fmt == Format.NLC:
|
||||
x = x.flatten(1, 2)
|
||||
elif fmt == Format.NCL:
|
||||
x = x.flatten(1, 2).transpose(1, 2)
|
||||
return x
|
@ -9,12 +9,13 @@ Based on code in:
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .format import Format, nchw_to
|
||||
from .helpers import to_2tuple
|
||||
from .trace_utils import _assert
|
||||
|
||||
@ -24,15 +25,18 @@ _logger = logging.getLogger(__name__)
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
output_fmt: Format
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
flatten: bool = True,
|
||||
output_fmt: Optional[str] = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
@ -41,7 +45,13 @@ class PatchEmbed(nn.Module):
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
if output_fmt is not None:
|
||||
self.flatten = False
|
||||
self.output_fmt = Format(output_fmt)
|
||||
else:
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.output_fmt = Format.NCHW
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
@ -52,7 +62,9 @@ class PatchEmbed(nn.Module):
|
||||
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
elif self.output_fmt != Format.NCHW:
|
||||
x = nchw_to(x, self.output_fmt)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
@ -15,26 +15,48 @@ from .weight_init import trunc_normal_
|
||||
|
||||
def gen_relative_position_index(
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int] = None,
|
||||
class_token: bool = False) -> torch.Tensor:
|
||||
k_size: Optional[Tuple[int, int]] = None,
|
||||
class_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Adapted with significant modifications from Swin / BeiT codebases
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
|
||||
if k_size is None:
|
||||
k_coords = q_coords
|
||||
k_size = q_size
|
||||
coords = torch.stack(
|
||||
torch.meshgrid([
|
||||
torch.arange(q_size[0]),
|
||||
torch.arange(q_size[1])
|
||||
])
|
||||
).flatten(1) # 2, Wh, Ww
|
||||
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
|
||||
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
|
||||
else:
|
||||
# different q vs k sizes is a WIP
|
||||
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
|
||||
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
|
||||
# FIXME different q vs k sizes is a WIP, need to better offset the two grids?
|
||||
q_coords = torch.stack(
|
||||
torch.meshgrid([
|
||||
torch.arange(q_size[0]),
|
||||
torch.arange(q_size[1])
|
||||
])
|
||||
).flatten(1) # 2, Wh, Ww
|
||||
k_coords = torch.stack(
|
||||
torch.meshgrid([
|
||||
torch.arange(k_size[0]),
|
||||
torch.arange(k_size[1])
|
||||
])
|
||||
).flatten(1)
|
||||
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
|
||||
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
|
||||
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
|
||||
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
|
||||
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
|
||||
num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3
|
||||
|
||||
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
|
||||
|
||||
if class_token:
|
||||
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
|
||||
# NOTE not intended or tested with MLP log-coords
|
||||
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
|
||||
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
|
||||
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
|
||||
relative_position_index[0, 0:] = num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = num_relative_distance - 2
|
||||
@ -59,7 +81,7 @@ class RelPosBias(nn.Module):
|
||||
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
|
||||
self.register_buffer(
|
||||
"relative_position_index",
|
||||
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
|
||||
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
@ -69,7 +91,7 @@ class RelPosBias(nn.Module):
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def get_bias(self) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
|
||||
# win_h * win_w, win_h * win_w, num_heads
|
||||
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
|
||||
return relative_position_bias.unsqueeze(0).contiguous()
|
||||
@ -148,7 +170,7 @@ class RelPosMlp(nn.Module):
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index",
|
||||
gen_relative_position_index(window_size),
|
||||
gen_relative_position_index(window_size).view(-1),
|
||||
persistent=False)
|
||||
|
||||
# get relative_coords_table
|
||||
@ -160,8 +182,7 @@ class RelPosMlp(nn.Module):
|
||||
def get_bias(self) -> torch.Tensor:
|
||||
relative_position_bias = self.mlp(self.rel_coords_log)
|
||||
if self.relative_position_index is not None:
|
||||
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
|
||||
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
|
||||
relative_position_bias = relative_position_bias.view(self.bias_shape)
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1)
|
||||
relative_position_bias = self.bias_act(relative_position_bias)
|
||||
|
@ -72,7 +72,8 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
|
||||
from ._factory import create_model, parse_model_name, safe_model_name
|
||||
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
|
||||
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
|
||||
register_notrace_module, register_notrace_function
|
||||
register_notrace_module, is_notrace_module, get_notrace_modules, \
|
||||
register_notrace_function, is_notrace_function, get_notrace_functions
|
||||
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
|
||||
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
|
||||
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
|
||||
|
@ -392,6 +392,9 @@ def build_model_with_cfg(
|
||||
# Wrap the model in a feature extraction module if enabled
|
||||
if features:
|
||||
feature_cls = FeatureListNet
|
||||
output_fmt = getattr(model, 'output_fmt', None)
|
||||
if output_fmt is not None:
|
||||
feature_cfg.setdefault('output_fmt', output_fmt)
|
||||
if 'feature_cls' in feature_cfg:
|
||||
feature_cls = feature_cfg.pop('feature_cls')
|
||||
if isinstance(feature_cls, str):
|
||||
@ -403,7 +406,7 @@ def build_model_with_cfg(
|
||||
else:
|
||||
assert False, f'Unknown feature class {feature_cls}'
|
||||
model = feature_cls(model, **feature_cfg)
|
||||
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)
|
||||
|
||||
return model
|
||||
|
@ -17,6 +17,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.layers import Format
|
||||
|
||||
|
||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||
|
||||
@ -98,7 +100,7 @@ class FeatureHooks:
|
||||
self,
|
||||
hooks: Sequence[str],
|
||||
named_modules: dict,
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
return_map: Sequence[Union[int, str]] = None,
|
||||
default_hook_type: str = 'forward',
|
||||
):
|
||||
# setup feature hooks
|
||||
@ -107,7 +109,7 @@ class FeatureHooks:
|
||||
for i, h in enumerate(hooks):
|
||||
hook_name = h['module']
|
||||
m = modules[hook_name]
|
||||
hook_id = out_map[i] if out_map else hook_name
|
||||
hook_id = return_map[i] if return_map else hook_name
|
||||
hook_fn = partial(self._collect_output_hook, hook_id)
|
||||
hook_type = h.get('hook_type', default_hook_type)
|
||||
if hook_type == 'forward_pre':
|
||||
@ -153,11 +155,11 @@ def _get_feature_info(net, out_indices):
|
||||
assert False, "Provided feature_info is not valid"
|
||||
|
||||
|
||||
def _get_return_layers(feature_info, out_map):
|
||||
def _get_return_layers(feature_info, return_map):
|
||||
module_names = feature_info.module_name()
|
||||
return_layers = {}
|
||||
for i, name in enumerate(module_names):
|
||||
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
||||
return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
|
||||
return return_layers
|
||||
|
||||
|
||||
@ -180,7 +182,8 @@ class FeatureDictNet(nn.ModuleDict):
|
||||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
return_map: Sequence[Union[int, str]] = None,
|
||||
output_fmt: str = 'NCHW',
|
||||
feature_concat: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
):
|
||||
@ -188,18 +191,19 @@ class FeatureDictNet(nn.ModuleDict):
|
||||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
return_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
|
||||
first element e.g. `x[0]`
|
||||
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||
"""
|
||||
super(FeatureDictNet, self).__init__()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.output_fmt = Format(output_fmt)
|
||||
self.concat = feature_concat
|
||||
self.grad_checkpointing = False
|
||||
self.return_layers = {}
|
||||
|
||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||
return_layers = _get_return_layers(self.feature_info, return_map)
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = set(return_layers.keys())
|
||||
layers = OrderedDict()
|
||||
@ -253,6 +257,7 @@ class FeatureListNet(FeatureDictNet):
|
||||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
output_fmt: str = 'NCHW',
|
||||
feature_concat: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
):
|
||||
@ -264,9 +269,10 @@ class FeatureListNet(FeatureDictNet):
|
||||
first element e.g. `x[0]`
|
||||
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
|
||||
"""
|
||||
super(FeatureListNet, self).__init__(
|
||||
super().__init__(
|
||||
model,
|
||||
out_indices=out_indices,
|
||||
output_fmt=output_fmt,
|
||||
feature_concat=feature_concat,
|
||||
flatten_sequential=flatten_sequential,
|
||||
)
|
||||
@ -292,8 +298,9 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
self,
|
||||
model: nn.Module,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
out_map: Sequence[Union[int, str]] = None,
|
||||
out_as_dict: bool = False,
|
||||
return_map: Sequence[Union[int, str]] = None,
|
||||
return_dict: bool = False,
|
||||
output_fmt: str = 'NCHW',
|
||||
no_rewrite: bool = False,
|
||||
flatten_sequential: bool = False,
|
||||
default_hook_type: str = 'forward',
|
||||
@ -303,17 +310,18 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
Args:
|
||||
model: Model from which to extract features.
|
||||
out_indices: Output indices of the model features to extract.
|
||||
out_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
out_as_dict: Output features as a dict.
|
||||
return_map: Return id mapping for each output index, otherwise str(index) is used.
|
||||
return_dict: Output features as a dict.
|
||||
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
|
||||
flatten_sequential arg must also be False if this is set True.
|
||||
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
|
||||
default_hook_type: The default hook type to use if not specified in model.feature_info.
|
||||
"""
|
||||
super(FeatureHookNet, self).__init__()
|
||||
super().__init__()
|
||||
assert not torch.jit.is_scripting()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.out_as_dict = out_as_dict
|
||||
self.return_dict = return_dict
|
||||
self.output_fmt = Format(output_fmt)
|
||||
self.grad_checkpointing = False
|
||||
|
||||
layers = OrderedDict()
|
||||
@ -340,7 +348,7 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
break
|
||||
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
|
||||
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
@ -356,4 +364,4 @@ class FeatureHookNet(nn.ModuleDict):
|
||||
else:
|
||||
x = module(x)
|
||||
out = self.hooks.get_output(x.device)
|
||||
return out if self.out_as_dict else list(out.values())
|
||||
return out if self.return_dict else list(out.values())
|
||||
|
@ -19,6 +19,11 @@ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSa
|
||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||
|
||||
__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
|
||||
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
|
||||
'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
|
||||
|
||||
|
||||
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
|
||||
# BUT modules from timm.models should use the registration mechanism below
|
||||
_leaf_modules = {
|
||||
@ -35,10 +40,6 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
|
||||
'FeatureGraphNet', 'GraphExtractNet']
|
||||
|
||||
|
||||
def register_notrace_module(module: Type[nn.Module]):
|
||||
"""
|
||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||
@ -47,6 +48,14 @@ def register_notrace_module(module: Type[nn.Module]):
|
||||
return module
|
||||
|
||||
|
||||
def is_notrace_module(module: Type[nn.Module]):
|
||||
return module in _leaf_modules
|
||||
|
||||
|
||||
def get_notrace_modules():
|
||||
return list(_leaf_modules)
|
||||
|
||||
|
||||
# Functions we want to autowrap (treat them as leaves)
|
||||
_autowrap_functions = set()
|
||||
|
||||
@ -59,6 +68,14 @@ def register_notrace_function(func: Callable):
|
||||
return func
|
||||
|
||||
|
||||
def is_notrace_function(func: Callable):
|
||||
return func in _autowrap_functions
|
||||
|
||||
|
||||
def get_notrace_functions():
|
||||
return list(_autowrap_functions)
|
||||
|
||||
|
||||
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
return _create_feature_extractor(
|
||||
|
@ -406,7 +406,7 @@ class ConvNeXt(nn.Module):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
||||
self.head.reset(num_classes, global_pool=global_pool)
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
@ -415,7 +415,7 @@ class ConvNeXt(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
return self.head(x, pre_logits=pre_logits)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -359,10 +359,9 @@ class DLA(nn.Module):
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.fc(x)
|
||||
return self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return self.flatten(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -298,10 +298,9 @@ class DPN(nn.Module):
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.classifier(x)
|
||||
return self.flatten(x)
|
||||
x = self.classifier(x)
|
||||
return self.flatten(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -18,6 +18,7 @@ This impl is/has:
|
||||
# Written by Jianwei Yang (jianwyan@microsoft.com)
|
||||
# --------------------------------------------------------
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -35,15 +36,15 @@ __all__ = ['FocalNet']
|
||||
class FocalModulation(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim: int,
|
||||
focal_window,
|
||||
focal_level,
|
||||
focal_factor=2,
|
||||
bias=True,
|
||||
use_post_norm=False,
|
||||
normalize_modulator=False,
|
||||
proj_drop=0.,
|
||||
norm_layer=LayerNorm2d,
|
||||
focal_level: int,
|
||||
focal_factor: int = 2,
|
||||
bias: bool = True,
|
||||
use_post_norm: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: Callable = LayerNorm2d,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -118,36 +119,38 @@ class LayerScale2d(nn.Module):
|
||||
|
||||
|
||||
class FocalNetBlock(nn.Module):
|
||||
r""" Focal Modulation Network Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
proj_drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
focal_level (int): Number of focal levels.
|
||||
focal_window (int): Focal window size at first focal level
|
||||
layerscale_value (float): Initial layerscale value
|
||||
use_post_norm (bool): Whether to use layernorm after modulation
|
||||
""" Focal Modulation Network Block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
mlp_ratio=4.,
|
||||
focal_level=1,
|
||||
focal_window=3,
|
||||
use_post_norm=False,
|
||||
use_post_norm_in_modulation=False,
|
||||
normalize_modulator=False,
|
||||
layerscale_value=1e-4,
|
||||
proj_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=LayerNorm2d,
|
||||
dim: int,
|
||||
mlp_ratio: float = 4.,
|
||||
focal_level: int = 1,
|
||||
focal_window: int = 3,
|
||||
use_post_norm: bool = False,
|
||||
use_post_norm_in_modulation: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
layerscale_value: float = 1e-4,
|
||||
proj_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm2d,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
focal_level: Number of focal levels.
|
||||
focal_window: Focal window size at first focal level.
|
||||
use_post_norm: Whether to use layer norm after modulation.
|
||||
use_post_norm_in_modulation: Whether to use layer norm in modulation.
|
||||
layerscale_value: Initial layerscale value.
|
||||
proj_drop: Dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mlp_ratio = mlp_ratio
|
||||
@ -197,42 +200,45 @@ class FocalNetBlock(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
class FocalNetStage(nn.Module):
|
||||
""" A basic Focal Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
depth (int): Number of blocks.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (bool): Downsample layer at start of the layer. Default: True
|
||||
focal_level (int): Number of focal levels
|
||||
focal_window (int): Focal window size at first focal level
|
||||
layerscale_value (float): Initial layerscale value
|
||||
use_post_norm (bool): Whether to use layer norm after modulation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
out_dim,
|
||||
depth,
|
||||
mlp_ratio=4.,
|
||||
downsample=True,
|
||||
focal_level=1,
|
||||
focal_window=1,
|
||||
use_overlap_down=False,
|
||||
use_post_norm=False,
|
||||
use_post_norm_in_modulation=False,
|
||||
normalize_modulator=False,
|
||||
layerscale_value=1e-4,
|
||||
proj_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=LayerNorm2d,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
depth: int,
|
||||
mlp_ratio: float = 4.,
|
||||
downsample: bool = True,
|
||||
focal_level: int = 1,
|
||||
focal_window: int = 1,
|
||||
use_overlap_down: bool = False,
|
||||
use_post_norm: bool = False,
|
||||
use_post_norm_in_modulation: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
layerscale_value: float = 1e-4,
|
||||
proj_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
norm_layer: Callable = LayerNorm2d,
|
||||
):
|
||||
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
out_dim: Number of output channels.
|
||||
depth: Number of blocks.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
downsample: Downsample layer at start of the layer.
|
||||
focal_level: Number of focal levels
|
||||
focal_window: Focal window size at first focal level
|
||||
use_overlap_down: User overlapped convolution in downsample layer.
|
||||
use_post_norm: Whether to use layer norm after modulation.
|
||||
use_post_norm_in_modulation: Whether to use layer norm in modulation.
|
||||
layerscale_value: Initial layerscale value
|
||||
proj_drop: Dropout rate for projections.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
@ -281,22 +287,24 @@ class BasicLayer(nn.Module):
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
r"""
|
||||
Args:
|
||||
in_chs (int): Number of input image channels
|
||||
out_chs (int): Number of linear projection output channels
|
||||
stride (int): Downsample stride. Default: 4.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chs,
|
||||
out_chs,
|
||||
stride=4,
|
||||
overlap=False,
|
||||
norm_layer=None,
|
||||
in_chs: int,
|
||||
out_chs: int,
|
||||
stride: int = 4,
|
||||
overlap: bool = False,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
in_chs: Number of input image channels.
|
||||
out_chs: Number of linear projection output channels.
|
||||
stride: Downsample stride.
|
||||
overlap: Use overlapping convolutions if True.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
padding = 0
|
||||
@ -317,49 +325,47 @@ class Downsample(nn.Module):
|
||||
|
||||
|
||||
class FocalNet(nn.Module):
|
||||
r""" Focal Modulation Networks (FocalNets)
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
num_classes (int): Number of classes for classification head. Default: 1000
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
depths (tuple(int)): Depth of each Focal Transformer layer.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level.
|
||||
Default: [1, 1, 1, 1]
|
||||
focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
|
||||
use_overlap_down (bool): Whether to use convolutional embedding.
|
||||
use_post_norm (bool): Whether to use layernorm after modulation (it helps stablize training of large models)
|
||||
layerscale_value (float): Value for layer scale. Default: 1e-4
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
|
||||
"""" Focal Modulation Networks (FocalNets)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='avg',
|
||||
embed_dim=96,
|
||||
depths=(2, 2, 6, 2),
|
||||
mlp_ratio=4.,
|
||||
focal_levels=(2, 2, 2, 2),
|
||||
focal_windows=(3, 3, 3, 3),
|
||||
use_overlap_down=False,
|
||||
use_post_norm=False,
|
||||
use_post_norm_in_modulation=False,
|
||||
normalize_modulator=False,
|
||||
head_hidden_size=None,
|
||||
head_init_scale=1.0,
|
||||
layerscale_value=None,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer=partial(LayerNorm2d, eps=1e-5),
|
||||
**kwargs,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 96,
|
||||
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
||||
mlp_ratio: float = 4.,
|
||||
focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
|
||||
focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
|
||||
use_overlap_down: bool = False,
|
||||
use_post_norm: bool = False,
|
||||
use_post_norm_in_modulation: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
head_hidden_size: Optional[int] = None,
|
||||
head_init_scale: float = 1.0,
|
||||
layerscale_value: Optional[float] = None,
|
||||
drop_rate: bool = 0.,
|
||||
proj_drop_rate: bool = 0.,
|
||||
drop_path_rate: bool = 0.1,
|
||||
norm_layer: Callable = partial(LayerNorm2d, eps=1e-5),
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_chans: Number of input image channels.
|
||||
num_classes: Number of classes for classification head.
|
||||
embed_dim: Patch embedding dimension.
|
||||
depths: Depth of each Focal Transformer layer.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
focal_levels: How many focal levels at all stages. Note that this excludes the finest-grain level.
|
||||
focal_windows: The focal window size at all stages.
|
||||
use_overlap_down: Whether to use convolutional embedding.
|
||||
use_post_norm: Whether to use layernorm after modulation (it helps stablize training of large models)
|
||||
layerscale_value: Value for layer scale.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = len(depths)
|
||||
@ -382,7 +388,7 @@ class FocalNet(nn.Module):
|
||||
layers = []
|
||||
for i_layer in range(self.num_layers):
|
||||
out_dim = embed_dim[i_layer]
|
||||
layer = BasicLayer(
|
||||
layer = FocalNetStage(
|
||||
dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
depth=depths[i_layer],
|
||||
@ -438,10 +444,10 @@ class FocalNet(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.classifier.fc
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.classifier.reset(num_classes, global_pool=global_pool)
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
@ -475,7 +481,7 @@ def _init_weights(module, name=None, head_init_scale=1.0):
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.proj', 'classifier': 'head.fc',
|
||||
@ -498,19 +504,19 @@ default_cfgs = {
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
|
||||
"focalnet_large_fl3": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_large_fl4": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_xlarge_fl3": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_xlarge_fl4": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_huge_fl3": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
|
||||
num_classes=0),
|
||||
num_classes=21842),
|
||||
"focalnet_huge_fl4": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
|
||||
num_classes=0),
|
||||
@ -533,7 +539,7 @@ def checkpoint_filter_fn(state_dict, model: FocalNet):
|
||||
k = re.sub(r'norm([0-9])', r'norm\1_post', k)
|
||||
k = k.replace('ln.', 'norm.')
|
||||
k = k.replace('head', 'head.fc')
|
||||
if dest_dict[k].shape != v.shape:
|
||||
if k in dest_dict and dest_dict[k].numel() == v.numel() and dest_dict[k].shape != v.shape:
|
||||
v = v.reshape(dest_dict[k].shape)
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
@ -29,12 +29,11 @@ import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
|
||||
get_attn, get_act_layer, get_norm_layer, _assert
|
||||
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply
|
||||
from ._registry import register_model
|
||||
from .vision_transformer_relpos import RelPosBias # FIXME move to common location
|
||||
|
||||
__all__ = ['GlobalContextVit']
|
||||
|
||||
@ -222,7 +221,7 @@ class WindowAttentionGlobal(nn.Module):
|
||||
q, k, v = qkv.unbind(0)
|
||||
q = q * self.scale
|
||||
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
|
||||
attn = self.rel_pos(attn)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
@ -145,13 +145,12 @@ class MobileNetV3(nn.Module):
|
||||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
x = self.flatten(x)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.flatten(x)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
return x
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -17,86 +17,24 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
||||
# --------------------------------------------------------
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq, named_apply
|
||||
from ._registry import register_model
|
||||
from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
|
||||
from .vision_transformer import get_init_weights_vit
|
||||
|
||||
__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swin_base_patch4_window12_384': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_base_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_large_patch4_window12_384': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_large_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_small_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_tiny_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_base_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_base_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_s3_tiny_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
|
||||
),
|
||||
'swin_s3_small_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
|
||||
),
|
||||
'swin_s3_base_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
|
||||
)
|
||||
}
|
||||
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
def window_partition(x, window_size: int):
|
||||
@ -132,7 +70,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
|
||||
return x
|
||||
|
||||
|
||||
def get_relative_position_index(win_h, win_w):
|
||||
def get_relative_position_index(win_h: int, win_w: int):
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
@ -145,21 +83,30 @@ def get_relative_position_index(win_h, win_w):
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
head_dim (int): Number of channels per head (dim // num_heads if not set)
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports shifted and non-shifted windows.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
qkv_bias: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Number of channels per head (dim // num_heads if not set)
|
||||
window_size: The height and width of the window.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
attn_drop: Dropout ratio of attention weight.
|
||||
proj_drop: Dropout ratio of output.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = to_2tuple(window_size) # Wh, Ww
|
||||
@ -198,7 +145,7 @@ class WindowAttention(nn.Module):
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
@ -206,7 +153,7 @@ class WindowAttention(nn.Module):
|
||||
|
||||
if mask is not None:
|
||||
num_win = mask.shape[0]
|
||||
attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
@ -221,28 +168,41 @@ class WindowAttention(nn.Module):
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
r""" Swin Transformer Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resulotion.
|
||||
window_size (int): Window size.
|
||||
num_heads (int): Number of attention heads.
|
||||
head_dim (int): Enforce the number of channels per head
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
""" Swin Transformer Block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim: int,
|
||||
input_resolution: _int_or_tuple_2_t,
|
||||
num_heads: int = 4,
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
shift_size: int = 0,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
window_size: Window size.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Enforce the number of channels per head
|
||||
shift_size: Shift size for SW-MSA.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
@ -257,12 +217,23 @@ class SwinTransformerBlock(nn.Module):
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
|
||||
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
@ -285,17 +256,15 @@ class SwinTransformerBlock(nn.Module):
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
self.register_buffer("attn_mask", attn_mask)
|
||||
|
||||
def forward(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
B, H, W, C = x.shape
|
||||
_assert(H == self.input_resolution[0], "input feature has wrong size")
|
||||
_assert(W == self.input_resolution[1], "input feature has wrong size")
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
@ -319,176 +288,227 @@ class SwinTransformerBlock(nn.Module):
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
x = x.reshape(B, -1, C)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
x = x.reshape(B, H, W, C)
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
r""" Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int]): Resolution of input feature.
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
""" Patch Merging Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
|
||||
def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
out_dim: Number of output channels (or 2 * dim if None)
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim or 2 * dim
|
||||
self.norm = norm_layer(4 * dim)
|
||||
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, H*W, C
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
_assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % 2 == 0, f"x height ({H}) is not even.")
|
||||
_assert(W % 2 == 0, f"x width ({W}) is not even.")
|
||||
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
class SwinTransformerStage(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
head_dim (int): Channels per head (dim // num_heads if not set)
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
|
||||
|
||||
self,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
input_resolution: Tuple[int, int],
|
||||
depth: int,
|
||||
downsample: bool = True,
|
||||
num_heads: int = 4,
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: Union[List[float], float] = 0.,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
output_nchw: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
depth: Number of blocks.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Channels per head (dim // num_heads if not set)
|
||||
window_size: Local window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
downsample: Downsample layer at the end of the layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
|
||||
self.depth = depth
|
||||
self.use_nchw = output_nchw
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# patch merging layer
|
||||
if downsample:
|
||||
self.downsample = PatchMerging(
|
||||
dim=dim,
|
||||
out_dim=out_dim,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
else:
|
||||
assert dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.Sequential(*[
|
||||
SwinTransformerBlock(
|
||||
dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim,
|
||||
window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
|
||||
dim=out_dim,
|
||||
input_resolution=self.output_resolution,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_nchw:
|
||||
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
if self.use_nchw:
|
||||
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
r""" Swin Transformer
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
""" Swin Transformer
|
||||
|
||||
Args:
|
||||
img_size (int | tuple(int)): Input image size. Default 224
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
num_classes (int): Number of classes for classification head. Default: 1000
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
depths (tuple(int)): Depth of each Swin Transformer layer.
|
||||
num_heads (tuple(int)): Number of attention heads in different layers.
|
||||
head_dim (int, tuple(int)):
|
||||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs):
|
||||
self,
|
||||
img_size: _int_or_tuple_2_t = 224,
|
||||
patch_size: int = 4,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 96,
|
||||
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
||||
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
||||
weight_init: str = '',
|
||||
output_fmt: str = 'NHWC',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of input image channels.
|
||||
num_classes: Number of classes for classification head.
|
||||
embed_dim: Patch embedding dimension.
|
||||
depths: Depth of each Swin Transformer layer.
|
||||
num_heads: Number of attention heads in different layers.
|
||||
head_dim: Dimension of self-attention heads.
|
||||
window_size: Window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop_rate: Dropout rate.
|
||||
attn_drop_rate (float): Attention dropout rate.
|
||||
drop_path_rate (float): Stochastic depth rate.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
assert global_pool in ('', 'avg')
|
||||
assert output_fmt in ('NCHW', 'NHWC')
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.output_fmt = output_fmt
|
||||
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||
self.output_nchw = self.output_fmt == 'NCHW' # bool flag for fwd
|
||||
self.feature_info = []
|
||||
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if patch_norm else None)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim[0],
|
||||
norm_layer=norm_layer,
|
||||
output_fmt='NHWC',
|
||||
)
|
||||
self.patch_grid = self.patch_embed.grid_size
|
||||
|
||||
# absolute position embedding
|
||||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# build layers
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
embed_out_dim = embed_dim[1:] + [None]
|
||||
head_dim = to_ntuple(self.num_layers)(head_dim)
|
||||
window_size = to_ntuple(self.num_layers)(window_size)
|
||||
mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
layers = []
|
||||
in_dim = embed_dim[0]
|
||||
scale = 1
|
||||
for i in range(self.num_layers):
|
||||
layers += [BasicLayer(
|
||||
dim=embed_dim[i],
|
||||
out_dim=embed_out_dim[i],
|
||||
input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)),
|
||||
out_dim = embed_dim[i]
|
||||
layers += [SwinTransformerStage(
|
||||
dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
input_resolution=(
|
||||
self.patch_grid[0] // scale,
|
||||
self.patch_grid[1] // scale
|
||||
),
|
||||
depth=depths[i],
|
||||
downsample=i > 0,
|
||||
num_heads=num_heads[i],
|
||||
head_dim=head_dim[i],
|
||||
window_size=window_size[i],
|
||||
@ -496,29 +516,36 @@ class SwinTransformer(nn.Module):
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i < self.num_layers - 1) else None
|
||||
output_nchw=self.output_nchw,
|
||||
)]
|
||||
in_dim = out_dim
|
||||
if i > 0:
|
||||
scale *= 2
|
||||
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
input_fmt=self.output_fmt,
|
||||
)
|
||||
if weight_init != 'skip':
|
||||
self.init_weights(weight_init)
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_weights(self, mode=''):
|
||||
assert mode in ('jax', 'jax_nlhb', 'moco', '')
|
||||
if self.absolute_pos_embed is not None:
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
|
||||
named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
nwd = {'absolute_pos_embed'}
|
||||
nwd = set()
|
||||
for n, _ in self.named_parameters():
|
||||
if 'relative_position_bias_table' in n:
|
||||
nwd.add(n)
|
||||
@ -527,7 +554,7 @@ class SwinTransformer(nn.Module):
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
return dict(
|
||||
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
|
||||
stem=r'^patch_embed', # stem and embed
|
||||
blocks=r'^layers\.(\d+)' if coarse else [
|
||||
(r'^layers\.(\d+).downsample', (0,)),
|
||||
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
|
||||
@ -542,28 +569,26 @@ class SwinTransformer(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'avg')
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.absolute_pos_embed is not None:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.pos_drop(x)
|
||||
if self.output_nchw:
|
||||
# patch embed always outputs NHWC, stage layers expect NCHW input
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
x = self.norm(x) # B L C
|
||||
if self.output_nchw:
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
return x if pre_logits else self.head(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
@ -571,15 +596,102 @@ class SwinTransformer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def checkpoint_filter_fn(
|
||||
state_dict,
|
||||
model,
|
||||
adapt_layer_scale=False,
|
||||
interpolation='bicubic',
|
||||
antialias=True,
|
||||
):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
import re
|
||||
out_dict = {}
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
for k, v in state_dict.items():
|
||||
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def _create_swin_transformer(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformer, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swin_base_patch4_window12_384': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'swin_base_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_large_patch4_window12_384': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'swin_large_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_small_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_tiny_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_base_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_base_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_s3_tiny_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
|
||||
),
|
||||
'swin_s3_small_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
|
||||
),
|
||||
'swin_s3_base_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
||||
|
@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
|
||||
# Written by Ze Liu
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -21,76 +21,14 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_tiny_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_tiny_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_small_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_small_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_base_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_base_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
|
||||
'swinv2_base_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192)
|
||||
),
|
||||
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
),
|
||||
'swinv2_large_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192)
|
||||
),
|
||||
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
),
|
||||
}
|
||||
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
def window_partition(x, window_size: Tuple[int, int]):
|
||||
@ -141,9 +79,15 @@ class WindowAttention(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||
pretrained_window_size=[0, 0]):
|
||||
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
pretrained_window_size=[0, 0],
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
@ -231,8 +175,8 @@ class WindowAttention(nn.Module):
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
num_win = mask.shape[0]
|
||||
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
@ -246,29 +190,42 @@ class WindowAttention(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
r""" Swin Transformer Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
pretrained_window_size (int): Window size in pretraining.
|
||||
class SwinTransformerV2Block(nn.Module):
|
||||
""" Swin Transformer Block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
|
||||
self,
|
||||
dim,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
num_heads: Number of attention heads.
|
||||
window_size: Window size.
|
||||
shift_size: Shift size for SW-MSA.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
pretrained_window_size: Window size in pretraining.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = to_2tuple(input_resolution)
|
||||
@ -280,13 +237,23 @@ class SwinTransformerBlock(nn.Module):
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size))
|
||||
dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size),
|
||||
)
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
@ -322,10 +289,7 @@ class SwinTransformerBlock(nn.Module):
|
||||
return tuple(window_size), tuple(shift_size)
|
||||
|
||||
def _attn(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
x = x.view(B, H, W, C)
|
||||
B, H, W, C = x.shape
|
||||
|
||||
# cyclic shift
|
||||
has_shift = any(self.shift_size)
|
||||
@ -350,113 +314,130 @@ class SwinTransformerBlock(nn.Module):
|
||||
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, H * W, C)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
B, H, W, C = x.shape
|
||||
x = x + self.drop_path1(self.norm1(self._attn(x)))
|
||||
x = x.reshape(B, -1, C)
|
||||
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
||||
x = x.reshape(B, H, W, C)
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
r""" Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int]): Resolution of input feature.
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
""" Patch Merging Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
||||
def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
out_dim (int): Number of output channels (or 2 * dim if None)
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(2 * dim)
|
||||
self.out_dim = out_dim or 2 * dim
|
||||
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
|
||||
self.norm = norm_layer(self.out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, H*W, C
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
_assert(H % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
_assert(W % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % 2 == 0, f"x height ({H}) is not even.")
|
||||
_assert(W % 2 == 0, f"x width ({W}) is not even.")
|
||||
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
|
||||
x = self.reduction(x)
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
pretrained_window_size (int): Local window size in pre-training.
|
||||
class SwinTransformerV2Stage(nn.Module):
|
||||
""" A Swin Transformer V2 Stage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, input_resolution, depth, num_heads, window_size,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0):
|
||||
|
||||
self,
|
||||
dim,
|
||||
out_dim,
|
||||
input_resolution,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
downsample=False,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0,
|
||||
output_nchw=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
depth: Number of blocks.
|
||||
num_heads: Number of attention heads.
|
||||
window_size: Local window size.
|
||||
downsample: Use downsample layer at start of the block.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
pretrained_window_size: Local window size in pretraining.
|
||||
output_nchw: Output tensors on NCHW format instead of NHWC.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
|
||||
self.depth = depth
|
||||
self.output_nchw = output_nchw
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# patch merging / downsample layer
|
||||
if downsample:
|
||||
self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer)
|
||||
else:
|
||||
assert dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim, input_resolution=input_resolution,
|
||||
num_heads=num_heads, window_size=window_size,
|
||||
SwinTransformerV2Block(
|
||||
dim=out_dim,
|
||||
input_resolution=self.output_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop, attn_drop=attn_drop,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
pretrained_window_size=pretrained_window_size)
|
||||
pretrained_window_size=pretrained_window_size,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
if self.output_nchw:
|
||||
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
if self.output_nchw:
|
||||
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
return x
|
||||
|
||||
def _init_respostnorm(self):
|
||||
@ -468,88 +449,117 @@ class BasicLayer(nn.Module):
|
||||
|
||||
|
||||
class SwinTransformerV2(nn.Module):
|
||||
r""" Swin Transformer V2
|
||||
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
|
||||
- https://arxiv.org/abs/2111.09883
|
||||
Args:
|
||||
img_size (int | tuple(int)): Input image size. Default 224
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
num_classes (int): Number of classes for classification head. Default: 1000
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
depths (tuple(int)): Depth of each Swin Transformer layer.
|
||||
num_heads (tuple(int)): Number of attention heads in different layers.
|
||||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
||||
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
|
||||
""" Swin Transformer V2
|
||||
|
||||
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
|
||||
- https://arxiv.org/abs/2111.09883
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
pretrained_window_sizes=(0, 0, 0, 0), **kwargs):
|
||||
self,
|
||||
img_size: _int_or_tuple_2_t = 224,
|
||||
patch_size: int = 4,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 96,
|
||||
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
||||
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
|
||||
output_fmt: str = 'NHWC',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of input image channels.
|
||||
num_classes: Number of classes for classification head.
|
||||
embed_dim: Patch embedding dimension.
|
||||
depths: Depth of each Swin Transformer stage (layer).
|
||||
num_heads: Number of attention heads in different layers.
|
||||
window_size: Window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop_rate: Dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
patch_norm: If True, add normalization after patch embedding.
|
||||
pretrained_window_sizes: Pretrained window sizes of each layer.
|
||||
output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
assert global_pool in ('', 'avg')
|
||||
assert output_fmt in ('NCHW', 'NHWC')
|
||||
self.global_pool = global_pool
|
||||
self.output_fmt = output_fmt
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||
self.output_nchw = self.output_fmt == 'NCHW'
|
||||
self.feature_info = []
|
||||
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim[0],
|
||||
norm_layer=norm_layer,
|
||||
output_fmt='NHWC',
|
||||
)
|
||||
|
||||
# absolute position embedding
|
||||
if ape:
|
||||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
else:
|
||||
self.absolute_pos_embed = None
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2 ** i_layer),
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
layers = []
|
||||
in_dim = embed_dim[0]
|
||||
scale = 1
|
||||
for i in range(self.num_layers):
|
||||
out_dim = embed_dim[i]
|
||||
layers += [SwinTransformerV2Stage(
|
||||
dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
input_resolution=(
|
||||
self.patch_embed.grid_size[0] // (2 ** i_layer),
|
||||
self.patch_embed.grid_size[1] // (2 ** i_layer)),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
self.patch_embed.grid_size[0] // scale,
|
||||
self.patch_embed.grid_size[1] // scale),
|
||||
depth=depths[i],
|
||||
downsample=i > 0,
|
||||
num_heads=num_heads[i],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
pretrained_window_size=pretrained_window_sizes[i_layer]
|
||||
)
|
||||
self.layers.append(layer)
|
||||
pretrained_window_size=pretrained_window_sizes[i],
|
||||
output_nchw=self.output_nchw,
|
||||
)]
|
||||
in_dim = out_dim
|
||||
if i > 0:
|
||||
scale *= 2
|
||||
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
input_fmt=self.output_fmt,
|
||||
)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
for bly in self.layers:
|
||||
@ -563,7 +573,7 @@ class SwinTransformerV2(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
nod = {'absolute_pos_embed'}
|
||||
nod = set()
|
||||
for n, m in self.named_modules():
|
||||
if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
|
||||
nod.add(n)
|
||||
@ -587,31 +597,26 @@ class SwinTransformerV2(nn.Module):
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'avg')
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.absolute_pos_embed is not None:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
x = self.norm(x) # B L C
|
||||
if self.output_nchw:
|
||||
# patch embed always outputs NHWC, stage layers expect NCHW input if output_nchw = True
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
if self.output_nchw:
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
return x if pre_logits else self.head(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
@ -620,25 +625,87 @@ class SwinTransformerV2(nn.Module):
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
import re
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
for k, v in state_dict.items():
|
||||
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
|
||||
continue # skip buffers that should not be persistent
|
||||
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformerV2, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_tiny_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_tiny_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_small_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_small_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_base_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_base_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
|
||||
),
|
||||
|
||||
'swinv2_base_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
),
|
||||
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
|
||||
),
|
||||
'swinv2_large_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
),
|
||||
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_tiny_window16_256(pretrained=False, **kwargs):
|
||||
"""
|
||||
|
@ -37,7 +37,7 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, Mlp, to_2tuple, _assert
|
||||
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply
|
||||
@ -48,60 +48,6 @@ __all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000,
|
||||
'input_size': (3, 224, 224),
|
||||
'pool_size': (7, 7),
|
||||
'crop_pct': 0.9,
|
||||
'interpolation': 'bicubic',
|
||||
'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj',
|
||||
'classifier': 'head',
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_cr_tiny_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_tiny_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_tiny_ns_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_small_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_base_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_ns_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_large_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_large_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_huge_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_huge_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_giant_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_giant_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
}
|
||||
|
||||
|
||||
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
|
||||
return x.permute(0, 2, 3, 1)
|
||||
@ -230,22 +176,14 @@ class WindowMultiHeadAttention(nn.Module):
|
||||
relative_position_bias = relative_position_bias.unsqueeze(0)
|
||||
return relative_position_bias
|
||||
|
||||
def _forward_sequential(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
"""
|
||||
# FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
|
||||
assert False, "not implemented"
|
||||
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
""" Forward pass.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
|
||||
mask (Optional[torch.Tensor]): Attention mask for the shift case
|
||||
|
||||
def _forward_batch(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs standard (non-sequential) scaled cosine self-attention.
|
||||
Returns:
|
||||
Output tensor of the shape [B * windows, N, C]
|
||||
"""
|
||||
Bw, L, C = x.shape
|
||||
|
||||
@ -272,22 +210,8 @@ class WindowMultiHeadAttention(nn.Module):
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
""" Forward pass.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
|
||||
mask (Optional[torch.Tensor]): Attention mask for the shift case
|
||||
|
||||
Returns:
|
||||
Output tensor of the shape [B * windows, N, C]
|
||||
"""
|
||||
if self.sequential_attn:
|
||||
return self._forward_sequential(x, mask)
|
||||
else:
|
||||
return self._forward_batch(x, mask)
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
class SwinTransformerV2CrBlock(nn.Module):
|
||||
r"""This class implements the Swin transformer block.
|
||||
|
||||
Args:
|
||||
@ -321,7 +245,7 @@ class SwinTransformerBlock(nn.Module):
|
||||
sequential_attn: bool = False,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super(SwinTransformerBlock, self).__init__()
|
||||
super(SwinTransformerV2CrBlock, self).__init__()
|
||||
self.dim: int = dim
|
||||
self.feat_size: Tuple[int, int] = feat_size
|
||||
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
|
||||
@ -410,9 +334,7 @@ class SwinTransformerBlock(nn.Module):
|
||||
self._make_attention_mask()
|
||||
|
||||
def _shifted_window_attn(self, x):
|
||||
H, W = self.feat_size
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, H, W, C)
|
||||
B, H, W, C = x.shape
|
||||
|
||||
# cyclic shift
|
||||
sh, sw = self.shift_size
|
||||
@ -441,7 +363,6 @@ class SwinTransformerBlock(nn.Module):
|
||||
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
|
||||
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
|
||||
|
||||
x = x.view(B, L, C)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -455,8 +376,12 @@ class SwinTransformerBlock(nn.Module):
|
||||
"""
|
||||
# post-norm branches (op -> norm -> drop)
|
||||
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
|
||||
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(B, -1, C)
|
||||
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
||||
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
|
||||
x = x.reshape(B, H, W, C)
|
||||
return x
|
||||
|
||||
|
||||
@ -479,12 +404,10 @@ class PatchMerging(nn.Module):
|
||||
Returns:
|
||||
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
# unfold + BCHW -> BHWC together
|
||||
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
|
||||
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
|
||||
x = self.norm(x)
|
||||
x = bhwc_to_bchw(self.reduction(x))
|
||||
x = self.reduction(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -511,7 +434,7 @@ class PatchEmbed(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerStage(nn.Module):
|
||||
class SwinTransformerV2CrStage(nn.Module):
|
||||
r"""This class implements a stage of the Swin transformer including multiple layers.
|
||||
|
||||
Args:
|
||||
@ -549,12 +472,16 @@ class SwinTransformerStage(nn.Module):
|
||||
extra_norm_stage: bool = False,
|
||||
sequential_attn: bool = False,
|
||||
) -> None:
|
||||
super(SwinTransformerStage, self).__init__()
|
||||
super(SwinTransformerV2CrStage, self).__init__()
|
||||
self.downscale: bool = downscale
|
||||
self.grad_checkpointing: bool = False
|
||||
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
|
||||
|
||||
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
|
||||
if downscale:
|
||||
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
|
||||
embed_dim = embed_dim * 2
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
def _extra_norm(index):
|
||||
i = index + 1
|
||||
@ -562,9 +489,8 @@ class SwinTransformerStage(nn.Module):
|
||||
return True
|
||||
return i == depth if extra_norm_stage else False
|
||||
|
||||
embed_dim = embed_dim * 2 if downscale else embed_dim
|
||||
self.blocks = nn.Sequential(*[
|
||||
SwinTransformerBlock(
|
||||
SwinTransformerV2CrBlock(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
feat_size=self.feat_size,
|
||||
@ -602,18 +528,15 @@ class SwinTransformerStage(nn.Module):
|
||||
Returns:
|
||||
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
|
||||
"""
|
||||
x = bchw_to_bhwc(x)
|
||||
x = self.downsample(x)
|
||||
B, C, H, W = x.shape
|
||||
L = H * W
|
||||
|
||||
x = bchw_to_bhwc(x).reshape(B, L, C)
|
||||
for block in self.blocks:
|
||||
# Perform checkpointing if utilized
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint.checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
x = bhwc_to_bchw(x.reshape(B, H, W, -1))
|
||||
x = bhwc_to_bchw(x)
|
||||
return x
|
||||
|
||||
|
||||
@ -676,39 +599,54 @@ class SwinTransformerV2Cr(nn.Module):
|
||||
self.img_size: Tuple[int, int] = img_size
|
||||
self.window_size: int = window_size
|
||||
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
|
||||
self.feature_info = []
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=norm_layer)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
|
||||
|
||||
drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
stages = []
|
||||
for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):
|
||||
stage_scale = 2 ** max(index - 1, 0)
|
||||
stages.append(
|
||||
SwinTransformerStage(
|
||||
embed_dim=embed_dim * stage_scale,
|
||||
depth=depth,
|
||||
downscale=index != 0,
|
||||
feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
drop_attn=attn_drop_rate,
|
||||
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
|
||||
extra_norm_period=extra_norm_period,
|
||||
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
|
||||
sequential_attn=sequential_attn,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
)
|
||||
in_dim = embed_dim
|
||||
in_scale = 1
|
||||
for stage_idx, (depth, num_heads) in enumerate(zip(depths, num_heads)):
|
||||
stages += [SwinTransformerV2CrStage(
|
||||
embed_dim=in_dim,
|
||||
depth=depth,
|
||||
downscale=stage_idx != 0,
|
||||
feat_size=(
|
||||
patch_grid_size[0] // in_scale,
|
||||
patch_grid_size[1] // in_scale
|
||||
),
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
drop_attn=attn_drop_rate,
|
||||
drop_path=dpr[stage_idx],
|
||||
extra_norm_period=extra_norm_period,
|
||||
extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths), # last stage ends w/ norm
|
||||
sequential_attn=sequential_attn,
|
||||
norm_layer=norm_layer,
|
||||
)]
|
||||
if stage_idx != 0:
|
||||
in_dim *= 2
|
||||
in_scale *= 2
|
||||
self.feature_info += [dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
self.global_pool: str = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
)
|
||||
|
||||
# current weight init skips custom init and uses pytorch layer defaults, seems to work well
|
||||
# FIXME more experiments needed
|
||||
@ -765,7 +703,7 @@ class SwinTransformerV2Cr(nn.Module):
|
||||
Returns:
|
||||
head (nn.Module): Current classification head
|
||||
"""
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
|
||||
"""Method results the classification head
|
||||
@ -774,10 +712,8 @@ class SwinTransformerV2Cr(nn.Module):
|
||||
num_classes (int): Number of classes to be predicted
|
||||
global_pool (str): Unused
|
||||
"""
|
||||
self.num_classes: int = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
@ -785,9 +721,7 @@ class SwinTransformerV2Cr(nn.Module):
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=(2, 3))
|
||||
return x if pre_logits else self.head(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
@ -815,29 +749,87 @@ def init_weights(module: nn.Module, name: str = ''):
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
for k, v in state_dict.items():
|
||||
if 'tau' in k:
|
||||
# convert old tau based checkpoints -> logit_scale (inverse)
|
||||
v = torch.log(1 / v)
|
||||
k = k.replace('tau', 'logit_scale')
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformerV2Cr, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000,
|
||||
'input_size': (3, 224, 224),
|
||||
'pool_size': (7, 7),
|
||||
'crop_pct': 0.9,
|
||||
'interpolation': 'bicubic',
|
||||
'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj',
|
||||
'classifier': 'head.fc',
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_cr_tiny_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_tiny_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_tiny_ns_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_small_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_256': _cfg(
|
||||
url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
|
||||
'swinv2_cr_base_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_base_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_ns_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_large_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_large_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_huge_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_huge_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_giant_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_giant_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
}
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_cr_tiny_384(pretrained=False, **kwargs):
|
||||
"""Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
|
||||
@ -915,6 +907,19 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_cr_small_ns_256(pretrained=False, **kwargs):
|
||||
"""Swin-S V2 CR @ 256x256, trained ImageNet-1k"""
|
||||
model_kwargs = dict(
|
||||
embed_dim=96,
|
||||
depths=(2, 2, 18, 2),
|
||||
num_heads=(3, 6, 12, 24),
|
||||
extra_norm_stage=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_256', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_cr_base_384(pretrained=False, **kwargs):
|
||||
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
|
||||
@ -961,8 +966,7 @@ def swinv2_cr_large_384(pretrained=False, **kwargs):
|
||||
num_heads=(6, 12, 24, 48),
|
||||
**kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
|
Loading…
x
Reference in New Issue
Block a user