All Swin models support spatial output: add output_fmt to v1/v2 and use ClassifierHead.

* update ClassifierHead to allow different input format
* add output format support to patch embed
* fix some flatten issues for a few conv head models
* add Format enum and helpers for tensor format (layout) choices
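
A rough sketch of what this enables (not part of the diff; the NHWC shape is an assumption based on swin_tiny's final 7x7 grid at 224x224):

```python
import torch
import timm

# Swin v1/v2 now keep a spatial NHWC layout through forward_features
# instead of returning flattened NLC tokens.
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)
print(getattr(model, 'output_fmt', 'NCHW'))  # 'NHWC' after this change
print(feats.shape)  # e.g. torch.Size([1, 7, 7, 768])
```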
Ross Wightman 2023-03-15 23:21:51 -07:00
parent c30a160d3e
commit acfd85ad68
20 changed files with 1417 additions and 998 deletions

View File

@@ -27,7 +27,8 @@ except ImportError:
import timm
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models._features_fx import _leaf_modules, _autowrap_functions
from timm.layers import Format, get_spatial_dim, get_channel_dim
from timm.models import get_notrace_modules, get_notrace_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@@ -37,10 +38,10 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*'
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*',
]
NUM_NON_STD = len(NON_STD_FILTERS)
@@ -52,7 +53,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
NON_STD_EXCLUDE_FILTERS = ['vit_gi*']
@@ -156,6 +157,10 @@ def test_model_default_cfgs(model_name, batch_size):
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
output_fmt = getattr(model, 'output_fmt', 'NCHW')
spatial_axis = get_spatial_dim(output_fmt)
assert len(spatial_axis) == 2 # TODO add 1D sequence support
feat_axis = get_channel_dim(output_fmt)
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
@@ -165,13 +170,14 @@ def test_model_default_cfgs(model_name, batch_size):
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
# test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
assert outputs.shape[-1] == model.num_features
assert outputs.shape[1] == model.num_features
# test model forward without pooling and classifier
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
@@ -179,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size):
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
@@ -187,7 +193,7 @@ def test_model_default_cfgs(model_name, batch_size):
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
# check classifier name matches default_cfg
if cfg.get('num_classes', None):
@@ -330,176 +336,181 @@ def test_model_forward_features(model_name, batch_size):
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
for e, o in zip(expected_channels, outputs):
assert e == o.shape[1]
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, outputs):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
if not _IS_MAC:
# MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer_kwargs = dict(
leaf_modules=get_notrace_modules(),
autowrap_functions=get_notrace_functions(),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer_kwargs = dict(
leaf_modules=list(_leaf_modules),
autowrap_functions=list(_autowrap_functions),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model
fx_model = create_feature_extractor(
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False)
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

View File

@@ -20,6 +20,7 @@ from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple

View File

@@ -9,31 +9,37 @@ Both a functional and a nn.Module version of the pooling is provided.
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .format import get_spatial_dim, get_channel_dim
_int_tuple_2_t = Union[int, Tuple[int, int]]
def adaptive_pool_feat_mult(pool_type='avg'):
if pool_type == 'catavgmax':
if pool_type.endswith('catavgmax'):
return 2
else:
return 1
def adaptive_avgmax_pool2d(x, output_size=1):
def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return 0.5 * (x_avg + x_max)
def adaptive_catavgmax_pool2d(x, output_size=1):
def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return torch.cat((x_avg, x_max), 1)
def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
"""Selectable global pooling function with dynamic input kernel size
"""
if pool_type == 'avg':
@@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
return x
class FastAdaptiveAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastAdaptiveAvgPool2d, self).__init__()
class FastAdaptiveAvgPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)
def forward(self, x):
return x.mean((2, 3), keepdim=not self.flatten)
return x.mean(self.dim, keepdim=not self.flatten)
class FastAdaptiveMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)
def forward(self, x):
return x.amax(self.dim, keepdim=not self.flatten)
class FastAdaptiveAvgMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)
def forward(self, x):
x_avg = x.mean(self.dim, keepdim=not self.flatten)
x_max = x.amax(self.dim, keepdim=not self.flatten)
return 0.5 * x_avg + 0.5 * x_max
class FastAdaptiveCatAvgMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveCatAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim_reduce = get_spatial_dim(input_fmt)
if flatten:
self.dim_cat = 1
else:
self.dim_cat = get_channel_dim(input_fmt)
def forward(self, x):
x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten)
x_max = x.amax(self.dim_reduce, keepdim=not self.flatten)
return torch.cat((x_avg, x_max), self.dim_cat)
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size
@@ -68,7 +113,7 @@ class AdaptiveAvgMaxPool2d(nn.Module):
class AdaptiveCatAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveCatAvgMaxPool2d, self).__init__()
self.output_size = output_size
@@ -79,26 +124,41 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size
"""
def __init__(self, output_size=1, pool_type='fast', flatten=False):
def __init__(
self,
output_size: _int_tuple_2_t = 1,
pool_type: str = 'fast',
flatten: bool = False,
input_fmt: str = 'NCHW',
):
super(SelectAdaptivePool2d, self).__init__()
assert input_fmt in ('NCHW', 'NHWC')
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
if pool_type == '':
if not pool_type:
self.pool = nn.Identity() # pass through
elif pool_type == 'fast':
assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(flatten)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
elif pool_type.startswith('fast') or input_fmt != 'NCHW':
assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
if pool_type.endswith('catavgmax'):
self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('avgmax'):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else:
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
self.flatten = nn.Identity()
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
assert input_fmt == 'NCHW'
if pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
self.pool = nn.AdaptiveAvgPool2d(output_size)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
def is_identity(self):
return not self.pool_type
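
A quick sketch of the new fast / channels-last pooling path added above (shapes illustrative):

```python
import torch
from timm.layers import SelectAdaptivePool2d

x = torch.randn(2, 7, 7, 768)  # NHWC feature map

# non-NCHW input is routed through the fast pooling modules
pool = SelectAdaptivePool2d(pool_type='avg', flatten=True, input_fmt='NHWC')
print(pool(x).shape)  # torch.Size([2, 768])

# catavgmax concatenates avg- and max-pooled features along the channel dim
cat_pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True, input_fmt='NHWC')
print(cat_pool(x).shape)  # torch.Size([2, 1536])
```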

View File

@@ -15,13 +15,23 @@ from .create_act import get_act_layer
from .create_norm import get_norm_layer
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
def _create_pool(
num_features: int,
num_classes: int,
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
global_pool = SelectAdaptivePool2d(
pool_type=pool_type,
flatten=flatten_in_pool,
input_fmt=input_fmt,
)
num_pooled_features = num_features * global_pool.feat_mult()
return global_pool, num_pooled_features
@@ -36,9 +46,25 @@ def _create_fc(num_features, num_classes, use_conv=False):
return fc
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
def create_classifier(
num_features: int,
num_classes: int,
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
global_pool, num_pooled_features = _create_pool(
num_features,
num_classes,
pool_type,
use_conv=use_conv,
input_fmt=input_fmt,
)
fc = _create_fc(
num_pooled_features,
num_classes,
use_conv=use_conv,
)
return global_pool, fc
@@ -52,6 +78,7 @@ class ClassifierHead(nn.Module):
pool_type: str = 'avg',
drop_rate: float = 0.,
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
"""
Args:
@@ -64,28 +91,43 @@
self.drop_rate = drop_rate
self.in_features = in_features
self.use_conv = use_conv
self.input_fmt = input_fmt
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.global_pool, self.fc = create_classifier(
in_features,
num_classes,
pool_type,
use_conv=use_conv,
input_fmt=input_fmt,
)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
if global_pool != self.global_pool.pool_type:
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
def reset(self, num_classes, pool_type=None):
if pool_type is not None and pool_type != self.global_pool.pool_type:
self.global_pool, self.fc = create_classifier(
self.in_features,
num_classes,
pool_type=pool_type,
use_conv=self.use_conv,
input_fmt=self.input_fmt,
)
self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
else:
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(
num_pooled_features,
num_classes,
use_conv=self.use_conv,
)
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
x = self.fc(x)
return self.flatten(x)
class NormMlpClassifierHead(nn.Module):

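With the input_fmt plumbing above, ClassifierHead can sit directly on channels-last features; a minimal sketch (the 7x7x768 input is illustrative):

```python
import torch
from timm.layers import ClassifierHead

head = ClassifierHead(768, num_classes=1000, pool_type='avg', input_fmt='NHWC')
x = torch.randn(2, 7, 7, 768)  # NHWC features, e.g. from a Swin stage
print(head(x).shape)                   # torch.Size([2, 1000])
print(head(x, pre_logits=True).shape)  # torch.Size([2, 768])

head.reset(10)        # swap in a 10-class classifier, keep pooling
print(head(x).shape)  # torch.Size([2, 10])
```
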
timm/layers/format.py Normal file (+58 lines)
View File

@@ -0,0 +1,58 @@
from enum import Enum
from typing import Union
import torch
class Format(str, Enum):
NCHW = 'NCHW'
NHWC = 'NHWC'
NCL = 'NCL'
NLC = 'NLC'
FormatT = Union[str, Format]
def get_spatial_dim(fmt: FormatT):
fmt = Format(fmt)
if fmt is Format.NLC:
dim = (1,)
elif fmt is Format.NCL:
dim = (2,)
elif fmt is Format.NHWC:
dim = (1, 2)
else:
dim = (2, 3)
return dim
def get_channel_dim(fmt: FormatT):
fmt = Format(fmt)
if fmt is Format.NHWC:
dim = 3
elif fmt is Format.NLC:
dim = 2
else:
dim = 1
return dim
def nchw_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(2).transpose(1, 2)
elif fmt == Format.NCL:
x = x.flatten(2)
return x
def nhwc_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NCHW:
x = x.permute(0, 3, 1, 2)
elif fmt == Format.NLC:
x = x.flatten(1, 2)
elif fmt == Format.NCL:
x = x.flatten(1, 2).transpose(1, 2)
return x
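
A round-trip sketch of the helpers in this new module:

```python
import torch
from timm.layers import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to

x = torch.randn(2, 64, 14, 14)  # NCHW
nhwc = nchw_to(x, Format.NHWC)  # (2, 14, 14, 64)
nlc = nchw_to(x, Format.NLC)    # (2, 196, 64)

print(get_channel_dim('NHWC'), get_spatial_dim('NHWC'))  # 3 (1, 2)
print(nhwc.mean(get_spatial_dim('NHWC')).shape)          # torch.Size([2, 64])
print(torch.equal(nhwc_to(nhwc, Format.NCHW), x))        # True
```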

View File

@@ -9,12 +9,13 @@ Based on code in:
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List
from typing import List, Optional, Callable
import torch
from torch import nn as nn
import torch.nn.functional as F
from .format import Format, nchw_to
from .helpers import to_2tuple
from .trace_utils import _assert
@@ -24,15 +25,18 @@ _logger = logging.getLogger(__name__)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
output_fmt: Format
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
):
super().__init__()
img_size = to_2tuple(img_size)
@@ -41,7 +45,13 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@@ -52,7 +62,9 @@
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x
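
For example, with the standard 224/16 configuration (shapes follow from the 14x14 patch grid):

```python
import torch
from timm.layers import PatchEmbed

x = torch.randn(1, 3, 224, 224)

embed_nlc = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)  # default: flatten to NLC
embed_nhwc = PatchEmbed(img_size=224, patch_size=16, embed_dim=768, output_fmt='NHWC')

print(embed_nlc(x).shape)   # torch.Size([1, 196, 768])
print(embed_nhwc(x).shape)  # torch.Size([1, 14, 14, 768])
```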

View File

@@ -15,26 +15,48 @@ from .weight_init import trunc_normal_
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
k_size: Optional[Tuple[int, int]] = None,
class_token: bool = False,
) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
# FIXME different q vs k sizes is a WIP, need to better offset the two grids?
q_coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
k_coords = torch.stack(
torch.meshgrid([
torch.arange(k_size[0]),
torch.arange(k_size[1])
])
).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
@@ -59,7 +81,7 @@ class RelPosBias(nn.Module):
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1),
persistent=False,
)
@@ -69,7 +91,7 @@
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
@@ -148,7 +170,7 @@ class RelPosMlp(nn.Module):
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
gen_relative_position_index(window_size).view(-1),
persistent=False)
# get relative_coords_table
@@ -160,8 +182,7 @@
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
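
For reference, a standalone sketch of the unique-based index construction used in gen_relative_position_index, for a 2x2 window attending to itself:

```python
import torch

q_size = (2, 2)
coords = torch.stack(torch.meshgrid([
    torch.arange(q_size[0]),
    torch.arange(q_size[1]),
])).flatten(1)                                 # 2, Wh*Ww
rel = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
rel = rel.permute(1, 2, 0)                     # Wh*Ww, Wh*Ww, 2
_, idx = torch.unique(rel.view(-1, 2), return_inverse=True, dim=0)
print(idx.view(4, 4))  # each unique relative offset maps to one bias-table slot
```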

View File

@@ -72,7 +72,8 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, register_notrace_function
register_notrace_module, is_notrace_module, get_notrace_modules, \
register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \

View File

@@ -392,6 +392,9 @@ def build_model_with_cfg(
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None:
feature_cfg.setdefault('output_fmt', output_fmt)
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
@@ -403,7 +406,7 @@
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)
return model
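
Feature wrappers now inherit a model's native layout; a sketch assuming swin_tiny supports features_only at this commit (shapes are assumptions):

```python
import torch
import timm

model = timm.create_model('swin_tiny_patch4_window7_224', features_only=True, pretrained=False)
x = torch.randn(1, 3, 224, 224)
for f in model(x):
    print(f.shape)  # NHWC maps, e.g. torch.Size([1, 56, 56, 96]) for the first level
print(model.feature_info.channels())  # e.g. [96, 192, 384, 768]
```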

View File

@@ -17,6 +17,8 @@ import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.layers import Format
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
@@ -98,7 +100,7 @@ class FeatureHooks:
self,
hooks: Sequence[str],
named_modules: dict,
out_map: Sequence[Union[int, str]] = None,
return_map: Sequence[Union[int, str]] = None,
default_hook_type: str = 'forward',
):
# setup feature hooks
@@ -107,7 +109,7 @@
for i, h in enumerate(hooks):
hook_name = h['module']
m = modules[hook_name]
hook_id = out_map[i] if out_map else hook_name
hook_id = return_map[i] if return_map else hook_name
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h.get('hook_type', default_hook_type)
if hook_type == 'forward_pre':
@@ -153,11 +155,11 @@ def _get_feature_info(net, out_indices):
assert False, "Provided feature_info is not valid"
def _get_return_layers(feature_info, out_map):
def _get_return_layers(feature_info, return_map):
module_names = feature_info.module_name()
return_layers = {}
for i, name in enumerate(module_names):
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i]
return return_layers
@@ -180,7 +182,8 @@ class FeatureDictNet(nn.ModuleDict):
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
return_map: Sequence[Union[int, str]] = None,
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
):
@@ -188,18 +191,19 @@ class FeatureDictNet(nn.ModuleDict):
Args:
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
return_map: Return id mapping for each output index, otherwise str(index) is used.
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
"""
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.output_fmt = Format(output_fmt)
self.concat = feature_concat
self.grad_checkpointing = False
self.return_layers = {}
return_layers = _get_return_layers(self.feature_info, out_map)
return_layers = _get_return_layers(self.feature_info, return_map)
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys())
layers = OrderedDict()
@@ -253,6 +257,7 @@ class FeatureListNet(FeatureDictNet):
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
):
@@ -264,9 +269,10 @@ class FeatureListNet(FeatureDictNet):
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
"""
super(FeatureListNet, self).__init__(
super().__init__(
model,
out_indices=out_indices,
output_fmt=output_fmt,
feature_concat=feature_concat,
flatten_sequential=flatten_sequential,
)
@@ -292,8 +298,9 @@ class FeatureHookNet(nn.ModuleDict):
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_as_dict: bool = False,
return_map: Sequence[Union[int, str]] = None,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',
@@ -303,17 +310,18 @@
Args:
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
out_as_dict: Output features as a dict.
return_map: Return id mapping for each output index, otherwise str(index) is used.
return_dict: Output features as a dict.
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
flatten_sequential arg must also be False if this is set True.
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
default_hook_type: The default hook type to use if not specified in model.feature_info.
"""
super(FeatureHookNet, self).__init__()
super().__init__()
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
self.return_dict = return_dict
self.output_fmt = Format(output_fmt)
self.grad_checkpointing = False
layers = OrderedDict()
@@ -340,7 +348,7 @@
break
assert not remaining, f'Return layers ({remaining}) are not present in model'
self.update(layers)
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map)
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
@@ -356,4 +364,4 @@
else:
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())
return out if self.return_dict else list(out.values())

View File

@@ -19,6 +19,11 @@ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSa
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
@@ -35,10 +40,6 @@ except ImportError:
pass
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
'FeatureGraphNet', 'GraphExtractNet']
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
@@ -47,6 +48,14 @@ def register_notrace_module(module: Type[nn.Module]):
return module
def is_notrace_module(module: Type[nn.Module]):
return module in _leaf_modules
def get_notrace_modules():
return list(_leaf_modules)
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
@@ -59,6 +68,14 @@ def register_notrace_function(func: Callable):
return func
def is_notrace_function(func: Callable):
return func in _autowrap_functions
def get_notrace_functions():
return list(_autowrap_functions)
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(

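The new query helpers make the trace registries introspectable; a sketch using a hypothetical leaf module:

```python
import torch.nn as nn
from timm.models import register_notrace_module, is_notrace_module, get_notrace_modules

@register_notrace_module  # treat as a leaf, don't trace into forward()
class MyDynamicBlock(nn.Module):  # hypothetical module with trace-unfriendly logic
    def forward(self, x):
        return x[:, :x.shape[1] // 2]

print(is_notrace_module(MyDynamicBlock))        # True
print(MyDynamicBlock in get_notrace_modules())  # True
```
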
View File

@@ -406,7 +406,7 @@ class ConvNeXt(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None):
self.head.reset(num_classes, global_pool=global_pool)
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.stem(x)
@@ -415,7 +415,7 @@ class ConvNeXt(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)

View File

@@ -359,10 +359,9 @@ class DLA(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
x = self.fc(x)
return self.flatten(x)
def forward(self, x):
x = self.forward_features(x)

View File

@@ -298,10 +298,9 @@ class DPN(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.classifier(x)
return self.flatten(x)
x = self.classifier(x)
return self.flatten(x)
def forward(self, x):
x = self.forward_features(x)

View File

@@ -18,6 +18,7 @@ This impl is/has:
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------
from functools import partial
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
@@ -35,15 +36,15 @@ __all__ = ['FocalNet']
class FocalModulation(nn.Module):
def __init__(
self,
dim,
dim: int,
focal_window,
focal_level,
focal_factor=2,
bias=True,
use_post_norm=False,
normalize_modulator=False,
proj_drop=0.,
norm_layer=LayerNorm2d,
focal_level: int,
focal_factor: int = 2,
bias: bool = True,
use_post_norm: bool = False,
normalize_modulator: bool = False,
proj_drop: float = 0.,
norm_layer: Callable = LayerNorm2d,
):
super().__init__()
@@ -118,36 +119,38 @@ class LayerScale2d(nn.Module):
class FocalNetBlock(nn.Module):
r""" Focal Modulation Network Block.
Args:
dim (int): Number of input channels.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
proj_drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
focal_level (int): Number of focal levels.
focal_window (int): Focal window size at first focal level
layerscale_value (float): Initial layerscale value
use_post_norm (bool): Whether to use layernorm after modulation
""" Focal Modulation Network Block.
"""
def __init__(
self,
dim,
mlp_ratio=4.,
focal_level=1,
focal_window=3,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
layerscale_value=1e-4,
proj_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=LayerNorm2d,
dim: int,
mlp_ratio: float = 4.,
focal_level: int = 1,
focal_window: int = 3,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
layerscale_value: float = 1e-4,
proj_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm2d,
):
"""
Args:
dim: Number of input channels.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
focal_level: Number of focal levels.
focal_window: Focal window size at first focal level.
use_post_norm: Whether to use layer norm after modulation.
use_post_norm_in_modulation: Whether to use layer norm in modulation.
layerscale_value: Initial layerscale value.
proj_drop: Dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
@@ -197,42 +200,45 @@ class FocalNetBlock(nn.Module):
return x
class BasicLayer(nn.Module):
class FocalNetStage(nn.Module):
""" A basic Focal Transformer layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (bool): Downsample layer at start of the layer. Default: True
focal_level (int): Number of focal levels
focal_window (int): Focal window size at first focal level
layerscale_value (float): Initial layerscale value
use_post_norm (bool): Whether to use layer norm after modulation
"""
def __init__(
self,
dim,
out_dim,
depth,
mlp_ratio=4.,
downsample=True,
focal_level=1,
focal_window=1,
use_overlap_down=False,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
layerscale_value=1e-4,
proj_drop=0.,
drop_path=0.,
norm_layer=LayerNorm2d,
dim: int,
out_dim: int,
depth: int,
mlp_ratio: float = 4.,
downsample: bool = True,
focal_level: int = 1,
focal_window: int = 1,
use_overlap_down: bool = False,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
layerscale_value: float = 1e-4,
proj_drop: float = 0.,
drop_path: float = 0.,
norm_layer: Callable = LayerNorm2d,
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels.
depth: Number of blocks.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
downsample: Downsample layer at start of the layer.
focal_level: Number of focal levels.
focal_window: Focal window size at first focal level.
use_overlap_down: Use overlapped convolution in downsample layer.
use_post_norm: Whether to use layer norm after modulation.
use_post_norm_in_modulation: Whether to use layer norm in modulation.
layerscale_value: Initial layerscale value
proj_drop: Dropout rate for projections.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.depth = depth
@@ -281,22 +287,24 @@ class BasicLayer(nn.Module):
class Downsample(nn.Module):
r"""
Args:
in_chs (int): Number of input image channels
out_chs (int): Number of linear projection output channels
stride (int): Downsample stride. Default: 4.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
in_chs,
out_chs,
stride=4,
overlap=False,
norm_layer=None,
in_chs: int,
out_chs: int,
stride: int = 4,
overlap: bool = False,
norm_layer: Optional[Callable] = None,
):
"""
Args:
in_chs: Number of input image channels.
out_chs: Number of linear projection output channels.
stride: Downsample stride.
overlap: Use overlapping convolutions if True.
norm_layer: Normalization layer.
"""
super().__init__()
self.stride = stride
padding = 0
@@ -317,49 +325,47 @@
class FocalNet(nn.Module):
r""" Focal Modulation Networks (FocalNets)
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Focal Transformer layer.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level.
Default: [1, 1, 1, 1]
focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
use_overlap_down (bool): Whether to use convolutional embedding.
use_post_norm (bool): Whether to use layernorm after modulation (it helps stablize training of large models)
layerscale_value (float): Value for layer scale. Default: 1e-4
drop_rate (float): Dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""" Focal Modulation Networks (FocalNets)
"""
def __init__(
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
embed_dim=96,
depths=(2, 2, 6, 2),
mlp_ratio=4.,
focal_levels=(2, 2, 2, 2),
focal_windows=(3, 3, 3, 3),
use_overlap_down=False,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
head_hidden_size=None,
head_init_scale=1.0,
layerscale_value=None,
drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=partial(LayerNorm2d, eps=1e-5),
**kwargs,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
mlp_ratio: float = 4.,
focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
use_overlap_down: bool = False,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
head_hidden_size: Optional[int] = None,
head_init_scale: float = 1.0,
layerscale_value: Optional[float] = None,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Callable = partial(LayerNorm2d, eps=1e-5),
):
"""
Args:
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
embed_dim: Patch embedding dimension.
depths: Depth of each Focal Transformer layer.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
focal_levels: How many focal levels at all stages. Note that this excludes the finest-grain level.
focal_windows: The focal window size at all stages.
use_overlap_down: Whether to use convolutional embedding.
use_post_norm: Whether to use layernorm after modulation (it helps stabilize training of large models)
layerscale_value: Value for layer scale.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
"""
super().__init__()
self.num_layers = len(depths)
@@ -382,7 +388,7 @@ class FocalNet(nn.Module):
layers = []
for i_layer in range(self.num_layers):
out_dim = embed_dim[i_layer]
layer = BasicLayer(
layer = FocalNetStage(
dim=in_dim,
out_dim=out_dim,
depth=depths[i_layer],
@@ -438,10 +444,10 @@ class FocalNet(nn.Module):
@torch.jit.ignore
def get_classifier(self):
return self.classifier.fc
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.classifier.reset(num_classes, global_pool=global_pool)
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.stem(x)
@@ -475,7 +481,7 @@ def _init_weights(module, name=None, head_init_scale=1.0):
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.proj', 'classifier': 'head.fc',
@@ -498,19 +504,19 @@ default_cfgs = {
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
"focalnet_large_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_large_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_huge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
num_classes=0),
num_classes=21842),
"focalnet_huge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
num_classes=0),
@@ -533,7 +539,7 @@ def checkpoint_filter_fn(state_dict, model: FocalNet):
k = re.sub(r'norm([0-9])', r'norm\1_post', k)
k = k.replace('ln.', 'norm.')
k = k.replace('head', 'head.fc')
if dest_dict[k].shape != v.shape:
if k in dest_dict and dest_dict[k].numel() == v.numel() and dest_dict[k].shape != v.shape:
v = v.reshape(dest_dict[k].shape)
out_dict[k] = v
return out_dict

View File

@@ -29,12 +29,11 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, _assert
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import register_model
from .vision_transformer_relpos import RelPosBias # FIXME move to common location
__all__ = ['GlobalContextVit']
@@ -222,7 +221,7 @@ class WindowAttentionGlobal(nn.Module):
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
attn = self.rel_pos(attn)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

View File

@@ -145,13 +145,12 @@ class MobileNetV3(nn.Module):
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
x = self.flatten(x)
if pre_logits:
return x.flatten(1)
else:
x = self.flatten(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
return x
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
def forward(self, x):
x = self.forward_features(x)

View File

@@ -17,86 +17,24 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# --------------------------------------------------------
import logging
import math
from typing import Optional
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model
from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
from .vision_transformer import get_init_weights_vit
__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'swin_base_patch4_window12_384': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'swin_base_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
),
'swin_large_patch4_window12_384': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'swin_large_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
),
'swin_small_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
),
'swin_tiny_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
),
'swin_base_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
'swin_base_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_large_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
'swin_large_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_s3_tiny_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
),
'swin_s3_small_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
),
'swin_s3_base_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
)
}
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
def window_partition(x, window_size: int):
@ -132,7 +70,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
return x
def get_relative_position_index(win_h, win_w):
def get_relative_position_index(win_h: int, win_w: int):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
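# Illustrative continuation of the computation above for a tiny 2x2 window:
# each of the 4 tokens gets a row of indices into the
# (2*Wh - 1) * (2*Ww - 1) = 9 entry relative position bias table.
import torch
win_h = win_w = 2
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += win_h - 1  # shift to start from 0
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
print(relative_coords.sum(-1))  # 4x4 tensor of bias table indices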
@ -145,21 +83,30 @@ def get_relative_position_index(win_h, win_w):
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
head_dim (int): Number of channels per head (dim // num_heads if not set)
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports shifted and non-shifted windows.
"""
def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim: int,
num_heads: int,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
qkv_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
):
"""
Args:
dim: Number of input channels.
num_heads: Number of attention heads.
head_dim: Number of channels per head (dim // num_heads if not set)
window_size: The height and width of the window.
qkv_bias: If True, add a learnable bias to query, key, value.
attn_drop: Dropout ratio of attention weight.
proj_drop: Dropout ratio of output.
"""
super().__init__()
self.dim = dim
self.window_size = to_2tuple(window_size) # Wh, Ww
@ -198,7 +145,7 @@ class WindowAttention(nn.Module):
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
@ -206,7 +153,7 @@ class WindowAttention(nn.Module):
if mask is not None:
num_win = mask.shape[0]
attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
@ -221,28 +168,41 @@ class WindowAttention(nn.Module):
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
window_size (int): Window size.
num_heads (int): Number of attention heads.
head_dim (int): Enforce the number of channels per head
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
""" Swin Transformer Block.
"""
def __init__(
self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim: int,
input_resolution: _int_or_tuple_2_t,
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
shift_size: int = 0,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
window_size: Window size.
num_heads: Number of attention heads.
head_dim: Enforce the number of channels per head
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@ -257,12 +217,23 @@ class SwinTransformerBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
head_dim=head_dim,
window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
@ -285,17 +256,15 @@ class SwinTransformerBlock(nn.Module):
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
B, H, W, C = x.shape
_assert(H == self.input_resolution[0], "input feature has wrong size")
_assert(W == self.input_resolution[1], "input feature has wrong size")
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
@ -319,176 +288,227 @@ class SwinTransformerBlock(nn.Module):
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.reshape(B, -1, C)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.reshape(B, H, W, C)
return x
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
""" Patch Merging Layer.
"""
def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels (or 2 * dim if None)
norm_layer: Normalization layer.
"""
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim or 2 * dim
self.norm = norm_layer(4 * dim)
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.norm(x)
x = self.reduction(x)
return x
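# Quick shape check of the NHWC merge above (arbitrary values): 2x2 spatial
# neighborhoods are folded into channels before norm and reduction.
import torch
x = torch.randn(1, 4, 4, 8)
B, H, W, C = x.shape
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
assert x.shape == (1, 2, 2, 32)  # (B, H/2, W/2, 4*C)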
class BasicLayer(nn.Module):
class SwinTransformerStage(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
head_dim (int): Channels per head (dim // num_heads if not set)
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def __init__(
self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None,
window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
self,
dim: int,
out_dim: int,
input_resolution: Tuple[int, int],
depth: int,
downsample: bool = True,
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.,
norm_layer: Callable = nn.LayerNorm,
output_nchw: bool = False,
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels.
input_resolution: Input resolution.
depth: Number of blocks.
downsample: Use a patch merging downsample layer at the start of the stage.
num_heads: Number of attention heads.
head_dim: Channels per head (dim // num_heads if not set)
window_size: Local window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
output_nchw: Accept and return NCHW tensors at the stage boundary instead of NHWC.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
self.depth = depth
self.use_nchw = output_nchw
self.grad_checkpointing = False
# patch merging layer
if downsample:
self.downsample = PatchMerging(
dim=dim,
out_dim=out_dim,
norm_layer=norm_layer,
)
else:
assert dim == out_dim
self.downsample = nn.Identity()
# build blocks
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim,
window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
dim=out_dim,
input_resolution=self.output_resolution,
num_heads=num_heads,
head_dim=head_dim,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
if self.use_nchw:
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
if self.downsample is not None:
x = self.downsample(x)
if self.use_nchw:
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
return x
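# Shape-only sketch of the optional NCHW round-trip in the stage forward:
# callers that request NCHW pay two permutes per stage, while the downsample
# and blocks themselves always run in NHWC.
import torch
x = torch.randn(2, 96, 56, 56)   # NCHW input from the caller
x = x.permute(0, 2, 3, 1)        # NCHW -> NHWC for downsample + blocks
# ... self.downsample(x); self.blocks(x) ...
x = x.permute(0, 3, 1, 2)        # NHWC -> NCHW for the caller
assert x.shape == (2, 96, 56, 56)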
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
""" Swin Transformer
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
head_dim (int, tuple(int)):
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None,
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs):
self,
img_size: _int_or_tuple_2_t = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Union[str, Callable] = nn.LayerNorm,
weight_init: str = '',
output_fmt: str = 'NHWC',
**kwargs,
):
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
embed_dim: Patch embedding dimension.
depths: Depth of each Swin Transformer layer.
num_heads: Number of attention heads in different layers.
head_dim: Dimension of self-attention heads.
window_size: Window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop_rate: Dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
output_fmt: Output tensor format, one of 'NHWC' (default) or 'NCHW'.
"""
super().__init__()
assert global_pool in ('', 'avg')
assert output_fmt in ('NCHW', 'NHWC')
self.num_classes = num_classes
self.global_pool = global_pool
self.output_fmt = output_fmt
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.output_nchw = self.output_fmt == 'NCHW' # bool flag for fwd
self.feature_info = []
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if patch_norm else None)
num_patches = self.patch_embed.num_patches
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim[0],
norm_layer=norm_layer,
output_fmt='NHWC',
)
self.patch_grid = self.patch_embed.grid_size
# absolute position embedding
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None
self.pos_drop = nn.Dropout(p=drop_rate)
# build layers
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
embed_out_dim = embed_dim[1:] + [None]
head_dim = to_ntuple(self.num_layers)(head_dim)
window_size = to_ntuple(self.num_layers)(window_size)
mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
layers = []
in_dim = embed_dim[0]
scale = 1
for i in range(self.num_layers):
layers += [BasicLayer(
dim=embed_dim[i],
out_dim=embed_out_dim[i],
input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)),
out_dim = embed_dim[i]
layers += [SwinTransformerStage(
dim=in_dim,
out_dim=out_dim,
input_resolution=(
self.patch_grid[0] // scale,
self.patch_grid[1] // scale
),
depth=depths[i],
downsample=i > 0,
num_heads=num_heads[i],
head_dim=head_dim[i],
window_size=window_size[i],
@ -496,29 +516,36 @@ class SwinTransformer(nn.Module):
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
drop_path=dpr[i],
norm_layer=norm_layer,
downsample=PatchMerging if (i < self.num_layers - 1) else None
output_nchw=self.output_nchw,
)]
in_dim = out_dim
if i > 0:
scale *= 2
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
input_fmt=self.output_fmt,
)
if weight_init != 'skip':
self.init_weights(weight_init)
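# Sketch of the per-stage stochastic depth split used in __init__ above,
# for the default depths=(2, 2, 6, 2) and drop_path_rate=0.1:
import torch
depths = (2, 2, 6, 2)
dpr = [x.tolist() for x in torch.linspace(0, 0.1, sum(depths)).split(depths)]
print([len(d) for d in dpr])  # [2, 2, 6, 2]; rates rise linearly toward 0.1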
@torch.jit.ignore
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
if self.absolute_pos_embed is not None:
trunc_normal_(self.absolute_pos_embed, std=.02)
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
nwd = {'absolute_pos_embed'}
nwd = set()
for n, _ in self.named_parameters():
if 'relative_position_bias_table' in n:
nwd.add(n)
@ -527,7 +554,7 @@ class SwinTransformer(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
stem=r'^patch_embed', # stem and embed
blocks=r'^layers\.(\d+)' if coarse else [
(r'^layers\.(\d+).downsample', (0,)),
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
@ -542,28 +569,26 @@ class SwinTransformer(nn.Module):
@torch.jit.ignore
def get_classifier(self):
return self.head
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
if self.output_nchw:
# patch embed always outputs NHWC, stage layers expect NCHW input
x = x.permute(0, 3, 1, 2)
x = self.layers(x)
x = self.norm(x) # B L C
if self.output_nchw:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
else:
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
return x if pre_logits else self.head(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -571,15 +596,102 @@ class SwinTransformer(nn.Module):
return x
def checkpoint_filter_fn(
state_dict,
model,
adapt_layer_scale=False,
interpolation='bicubic',
antialias=True,
):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
out_dict = {}
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
for k, v in state_dict.items():
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
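# The patch merging downsample moved from the end of stage i to the start of
# stage i + 1, so original checkpoint keys are shifted by one stage:
import re
k = 'layers.0.downsample.reduction.weight'
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
assert k == 'layers.1.downsample.reduction.weight'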
def _create_swin_transformer(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
SwinTransformer, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
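# Usage sketch (assuming this commit's code): spatial NHWC features from
# forward_features, pooled logits from the new ClassifierHead.
import torch
import timm
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)   # torch.Size([1, 7, 7, 768]), NHWC
logits = model(x)                   # torch.Size([1, 1000])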
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = {
'swin_base_patch4_window12_384': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_base_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
),
'swin_large_patch4_window12_384': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_large_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
),
'swin_small_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
),
'swin_tiny_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
),
'swin_base_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
'swin_base_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_large_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
'swin_large_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_s3_tiny_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
),
'swin_s3_small_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
),
'swin_s3_base_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
)
}
@register_model
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k

View File

@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# Written by Ze Liu
# --------------------------------------------------------
import math
from typing import Tuple, Optional
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -21,76 +21,14 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'swinv2_tiny_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_tiny_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_small_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_small_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192)
),
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'swinv2_large_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192)
),
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
input_size=(3, 256, 256)
),
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
}
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
def window_partition(x, window_size: Tuple[int, int]):
@ -141,9 +79,15 @@ class WindowAttention(nn.Module):
"""
def __init__(
self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
pretrained_window_size=[0, 0],
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
@ -231,8 +175,8 @@ class WindowAttention(nn.Module):
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
num_win = mask.shape[0]
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
@ -246,29 +190,42 @@ class WindowAttention(nn.Module):
return x
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pretraining.
class SwinTransformerV2Block(nn.Module):
""" Swin Transformer Block.
"""
def __init__(
self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
pretrained_window_size=0,
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
num_heads: Number of attention heads.
window_size: Window size.
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
pretrained_window_size: Window size in pretraining.
"""
super().__init__()
self.dim = dim
self.input_resolution = to_2tuple(input_resolution)
@ -280,13 +237,23 @@ class SwinTransformerBlock(nn.Module):
self.mlp_ratio = mlp_ratio
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size))
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size),
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -322,10 +289,7 @@ class SwinTransformerBlock(nn.Module):
return tuple(window_size), tuple(shift_size)
def _attn(self, x):
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
x = x.view(B, H, W, C)
B, H, W, C = x.shape
# cyclic shift
has_shift = any(self.shift_size)
@ -350,113 +314,130 @@ class SwinTransformerBlock(nn.Module):
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
return x
def forward(self, x):
B, H, W, C = x.shape
x = x + self.drop_path1(self.norm1(self._attn(x)))
x = x.reshape(B, -1, C)
x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = x.reshape(B, H, W, C)
return x
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
""" Patch Merging Layer.
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
"""
Args:
dim (int): Number of input channels.
out_dim (int): Number of output channels (or 2 * dim if None)
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
self.out_dim = out_dim or 2 * dim
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
self.norm = norm_layer(self.out_dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0, f"x size ({H}*{W}) are not even.")
_assert(W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.reduction(x)
x = self.norm(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
pretrained_window_size (int): Local window size in pre-training.
class SwinTransformerV2Stage(nn.Module):
""" A Swin Transformer V2 Stage.
"""
def __init__(
self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0):
self,
dim,
out_dim,
input_resolution,
depth,
num_heads,
window_size,
downsample=False,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
pretrained_window_size=0,
output_nchw=False,
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels.
input_resolution: Input resolution.
depth: Number of blocks.
num_heads: Number of attention heads.
window_size: Local window size.
downsample: Use a patch merging downsample layer at the start of the stage.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
pretrained_window_size: Local window size in pretraining.
output_nchw: Output tensors in NCHW format instead of NHWC.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
self.depth = depth
self.output_nchw = output_nchw
self.grad_checkpointing = False
# patch merging / downsample layer
if downsample:
self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer)
else:
assert dim == out_dim
self.downsample = nn.Identity()
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
SwinTransformerV2Block(
dim=out_dim,
input_resolution=self.output_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size)
pretrained_window_size=pretrained_window_size,
)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
def forward(self, x):
if self.output_nchw:
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
x = self.downsample(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
x = self.downsample(x)
if self.output_nchw:
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
return x
def _init_respostnorm(self):
@ -468,88 +449,117 @@ class BasicLayer(nn.Module):
class SwinTransformerV2(nn.Module):
r""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/abs/2111.09883
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/abs/2111.09883
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
pretrained_window_sizes=(0, 0, 0, 0), **kwargs):
self,
img_size: _int_or_tuple_2_t = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Callable = nn.LayerNorm,
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
output_fmt: str = 'NHWC',
**kwargs,
):
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
embed_dim: Patch embedding dimension.
depths: Depth of each Swin Transformer stage (layer).
num_heads: Number of attention heads in different layers.
window_size: Window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop_rate: Dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
pretrained_window_sizes: Pretrained window sizes of each layer.
output_fmt: Output tensor format, one of 'NHWC' (default) or 'NCHW'.
"""
super().__init__()
self.num_classes = num_classes
assert global_pool in ('', 'avg')
assert output_fmt in ('NCHW', 'NHWC')
self.global_pool = global_pool
self.output_fmt = output_fmt
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.output_nchw = self.output_fmt == 'NCHW'
self.feature_info = []
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim[0],
norm_layer=norm_layer,
output_fmt='NHWC',
)
# absolute position embedding
if ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
else:
self.absolute_pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
layers = []
in_dim = embed_dim[0]
scale = 1
for i in range(self.num_layers):
out_dim = embed_dim[i]
layers += [SwinTransformerV2Stage(
dim=in_dim,
out_dim=out_dim,
input_resolution=(
self.patch_embed.grid_size[0] // (2 ** i_layer),
self.patch_embed.grid_size[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
self.patch_embed.grid_size[0] // scale,
self.patch_embed.grid_size[1] // scale),
depth=depths[i],
downsample=i > 0,
num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
drop_path=dpr[i],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer]
)
self.layers.append(layer)
pretrained_window_size=pretrained_window_sizes[i],
output_nchw=self.output_nchw,
)]
in_dim = out_dim
if i > 0:
scale *= 2
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
input_fmt=self.output_fmt,
)
self.apply(self._init_weights)
for bly in self.layers:
@ -563,7 +573,7 @@ class SwinTransformerV2(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
nod = {'absolute_pos_embed'}
nod = set()
for n, m in self.named_modules():
if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
nod.add(n)
@ -587,31 +597,26 @@ class SwinTransformerV2(nn.Module):
@torch.jit.ignore
def get_classifier(self):
return self.head
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
if self.output_nchw:
# patch embed always outputs NHWC, stage layers expect NCHW input if output_nchw = True
x = x.permute(0, 3, 1, 2)
x = self.layers(x)
if self.output_nchw:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
else:
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
return x if pre_logits else self.head(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -620,25 +625,87 @@ class SwinTransformerV2(nn.Module):
def checkpoint_filter_fn(state_dict, model):
import re
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
for k, v in state_dict.items():
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
continue # skip buffers that should not be persistent
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
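# Sketch of the buffer filtering above with a hypothetical state dict: keys
# for recomputed, non-persistent buffers are dropped rather than loaded.
skip = ('relative_position_index', 'relative_coords_table')
state = {
    'layers.0.blocks.0.attn.relative_position_index': 0,
    'patch_embed.proj.weight': 1,
}
filtered = {k: v for k, v in state.items() if not any(n in k for n in skip)}
assert list(filtered) == ['patch_embed.proj.weight']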
def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
SwinTransformerV2, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
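# Sketch of the new output_fmt plumbing (assuming kwargs pass through
# create_model as in this commit): the same weights can expose spatial
# features in NHWC (default) or NCHW.
import torch
import timm
m = timm.create_model('swinv2_tiny_window8_256', pretrained=False, output_fmt='NCHW')
f = m.forward_features(torch.randn(1, 3, 256, 256))
print(f.shape)  # torch.Size([1, 768, 8, 8]); NHWC default gives (1, 8, 8, 768)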
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = {
'swinv2_tiny_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
),
'swinv2_tiny_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
),
'swinv2_small_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
),
'swinv2_small_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
),
'swinv2_base_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
),
'swinv2_base_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
),
'swinv2_base_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
),
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'swinv2_large_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
),
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
}
@register_model
def swinv2_tiny_window16_256(pretrained=False, **kwargs):
"""

View File

@ -37,7 +37,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, to_2tuple, _assert
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
@ -48,60 +48,6 @@ __all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.9,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj',
'classifier': 'head',
**kwargs,
}
default_cfgs = {
'swinv2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_tiny_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_small_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
}
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
"""Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
return x.permute(0, 2, 3, 1)
@ -230,22 +176,14 @@ class WindowMultiHeadAttention(nn.Module):
relative_position_bias = relative_position_bias.unsqueeze(0)
return relative_position_bias
def _forward_sequential(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
"""
# FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
assert False, "not implemented"
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
""" Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
mask (Optional[torch.Tensor]): Attention mask for the shift case
def _forward_batch(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""This function performs standard (non-sequential) scaled cosine self-attention.
Returns:
Output tensor of the shape [B * windows, N, C]
"""
Bw, L, C = x.shape
@ -272,22 +210,8 @@ class WindowMultiHeadAttention(nn.Module):
x = self.proj_drop(x)
return x
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
""" Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
mask (Optional[torch.Tensor]): Attention mask for the shift case
Returns:
Output tensor of the shape [B * windows, N, C]
"""
if self.sequential_attn:
return self._forward_sequential(x, mask)
else:
return self._forward_batch(x, mask)
class SwinTransformerBlock(nn.Module):
class SwinTransformerV2CrBlock(nn.Module):
r"""This class implements the Swin transformer block.
Args:
@ -321,7 +245,7 @@ class SwinTransformerBlock(nn.Module):
sequential_attn: bool = False,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super(SwinTransformerBlock, self).__init__()
super(SwinTransformerV2CrBlock, self).__init__()
self.dim: int = dim
self.feat_size: Tuple[int, int] = feat_size
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
@ -410,9 +334,7 @@ class SwinTransformerBlock(nn.Module):
self._make_attention_mask()
def _shifted_window_attn(self, x):
H, W = self.feat_size
B, L, C = x.shape
x = x.view(B, H, W, C)
B, H, W, C = x.shape
# cyclic shift
sh, sw = self.shift_size
@ -441,7 +363,6 @@ class SwinTransformerBlock(nn.Module):
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
x = x.view(B, L, C)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -455,8 +376,12 @@ class SwinTransformerBlock(nn.Module):
"""
# post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
B, H, W, C = x.shape
x = x.reshape(B, -1, C)
x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
x = x.reshape(B, H, W, C)
return x
@ -479,12 +404,10 @@ class PatchMerging(nn.Module):
Returns:
output (torch.Tensor): Output tensor of the shape [B, H // 2, W // 2, 2 * C]
"""
B, C, H, W = x.shape
# unfold + BCHW -> BHWC together
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
B, H, W, C = x.shape
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.norm(x)
x = bhwc_to_bchw(self.reduction(x))
x = self.reduction(x)
return x
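A standalone check of the NHWC merge above, assuming even H and W:

import torch

B, H, W, C = 1, 8, 8, 96
x = torch.randn(B, H, W, C)
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
assert x.shape == (B, H // 2, W // 2, 4 * C)  # norm + linear reduction to 2 * C follow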
@@ -511,7 +434,7 @@ class PatchEmbed(nn.Module):
return x
class SwinTransformerStage(nn.Module):
class SwinTransformerV2CrStage(nn.Module):
r"""This class implements a stage of the Swin transformer including multiple layers.
Args:
@@ -549,12 +472,16 @@ class SwinTransformerStage(nn.Module):
extra_norm_stage: bool = False,
sequential_attn: bool = False,
) -> None:
super(SwinTransformerStage, self).__init__()
super(SwinTransformerV2CrStage, self).__init__()
self.downscale: bool = downscale
self.grad_checkpointing: bool = False
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
if downscale:
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
embed_dim = embed_dim * 2
else:
self.downsample = nn.Identity()
def _extra_norm(index):
i = index + 1
@@ -562,9 +489,8 @@ class SwinTransformerStage(nn.Module):
return True
return i == depth if extra_norm_stage else False
embed_dim = embed_dim * 2 if downscale else embed_dim
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
SwinTransformerV2CrBlock(
dim=embed_dim,
num_heads=num_heads,
feat_size=self.feat_size,
@@ -602,18 +528,15 @@ class SwinTransformerStage(nn.Module):
Returns:
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
"""
x = bchw_to_bhwc(x)
x = self.downsample(x)
B, C, H, W = x.shape
L = H * W
x = bchw_to_bhwc(x).reshape(B, L, C)
for block in self.blocks:
# Perform checkpointing if utilized
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
x = bhwc_to_bchw(x.reshape(B, H, W, -1))
x = bhwc_to_bchw(x)
return x
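In short, the stage now converts only at its boundaries and runs NHWC inside; a condensed sketch of the flow above:

def stage_forward(self, x):                       # x: (B, C, H, W)
    x = bchw_to_bhwc(x)                           # NHWC inside the stage
    x = self.downsample(x)                        # nn.Identity or NHWC PatchMerging
    for block in self.blocks:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint.checkpoint(block, x)   # trade compute for activation memory
        else:
            x = block(x)
    return bhwc_to_bchw(x)                        # back to NCHW between stages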
@@ -676,39 +599,54 @@ class SwinTransformerV2Cr(nn.Module):
self.img_size: Tuple[int, int] = img_size
self.window_size: int = window_size
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
self.feature_info = []
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer,
)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
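# e.g. drop_path_rate=0.1 with depths=(2, 2, 6, 2): torch.linspace(0, 0.1, 12).split((2, 2, 6, 2))
# yields four per-stage lists of linearly increasing rates, so deeper blocks get stronger stochastic depth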
stages = []
for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):
stage_scale = 2 ** max(index - 1, 0)
stages.append(
SwinTransformerStage(
embed_dim=embed_dim * stage_scale,
depth=depth,
downscale=index != 0,
feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate,
drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)
)
in_dim = embed_dim
in_scale = 1
for stage_idx, (depth, num_heads) in enumerate(zip(depths, num_heads)):
stages += [SwinTransformerV2CrStage(
embed_dim=in_dim,
depth=depth,
downscale=stage_idx != 0,
feat_size=(
patch_grid_size[0] // in_scale,
patch_grid_size[1] // in_scale
),
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate,
drop_attn=attn_drop_rate,
drop_path=dpr[stage_idx],
extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)]
if stage_idx != 0:
in_dim *= 2
in_scale *= 2
self.feature_info += [dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}')]
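# e.g. embed_dim=96: feature_info records num_chs 96/192/384/768 at reductions 4/8/16/32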
self.stages = nn.Sequential(*stages)
self.global_pool: str = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)
# current weight init skips custom init and uses pytorch layer defaults, seems to work well
# FIXME more experiments needed
@@ -765,7 +703,7 @@ class SwinTransformerV2Cr(nn.Module):
Returns:
head (nn.Module): Current classification head
"""
return self.head
return self.head.fc
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Method results the classification head
@@ -774,10 +712,8 @@ class SwinTransformerV2Cr(nn.Module):
num_classes (int): Number of classes to be predicted
global_pool (str): Optional override of the pooling type, passed through to the head
"""
self.num_classes: int = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
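Resetting thus delegates to the head; a usage sketch (model is any instantiated SwinTransformerV2Cr):

model.reset_classifier(10)                  # new 10-way head.fc, pooling unchanged
model.reset_classifier(0, global_pool='')   # strip pooling and classifier entirely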
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@@ -785,9 +721,7 @@ class SwinTransformerV2Cr(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(2, 3))
return x if pre_logits else self.head(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
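With ClassifierHead owning the pooling, the head paths line up as follows (a hedged sketch; shapes shown for the tiny variant at 224):

import torch
import timm

m = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)
feats = m.forward_features(x)                    # unpooled NCHW: (1, 768, 7, 7)
logits = m.forward_head(feats)                   # pooled + classified: (1, 1000)
pooled = m.forward_head(feats, pre_logits=True)  # pooled features: (1, 768)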
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
@@ -815,29 +749,87 @@ def init_weights(module: nn.Module, name: str = ''):
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
for k, v in state_dict.items():
if 'tau' in k:
# convert old tau based checkpoints -> logit_scale (inverse)
v = torch.log(1 / v)
k = k.replace('tau', 'logit_scale')
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
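The tau conversion follows from the attention logits being scaled by 1 / tau; a quick numeric check:

import torch

tau = torch.tensor([0.5])
logit_scale = torch.log(1 / tau)                   # what converted checkpoints store
assert torch.allclose(logit_scale.exp(), 1 / tau)  # exp recovers the 1 / tau scale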
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
SwinTransformerV2Cr, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
return model
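With out_indices wired into feature_cfg, feature extraction should now work as for other timm backbones; a hedged usage sketch:

import torch
import timm

m = timm.create_model('swinv2_cr_tiny_ns_224', features_only=True, out_indices=(0, 1, 2, 3))
feats = m(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # expect channels 96/192/384/768 at strides 4/8/16/32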
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.9,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj',
'classifier': 'head.fc',
**kwargs,
}
default_cfgs = {
'swinv2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_tiny_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_small_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_256': _cfg(
url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
'swinv2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
}
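These cfgs keep pool_size tied to the overall stride of 32 (patch stride 4, then three 2x merges); a quick consistency check over the dict above:

for name, cfg in default_cfgs.items():
    assert cfg['pool_size'][0] == cfg['input_size'][-1] // 32  # 224 -> 7, 256 -> 8, 384 -> 12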
@register_model
def swinv2_cr_tiny_384(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
@@ -915,6 +907,19 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_cr_small_ns_256(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 256x256, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_256', pretrained=pretrained, **model_kwargs)
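A minimal instantiation check for the new 256x256 variant (assuming torch and timm imported as above):

m = timm.create_model('swinv2_cr_small_ns_256', pretrained=False)
out = m(torch.randn(1, 3, 256, 256))
assert out.shape == (1, 1000)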
@register_model
def swinv2_cr_base_384(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
@@ -961,8 +966,7 @@ def swinv2_cr_large_384(pretrained=False, **kwargs):
num_heads=(6, 12, 24, 48),
**kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs)
@register_model