diff --git a/tests/test_models.py b/tests/test_models.py index 3d8f8515..9fa745ee 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,8 @@ except ImportError: import timm from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value -from timm.models._features_fx import _leaf_modules, _autowrap_functions +from timm.layers import Format, get_spatial_dim, get_channel_dim +from timm.models import get_notrace_modules, get_notrace_functions if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -37,10 +38,10 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ - 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', + 'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'eva_*', 'flexivit*' + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', + 'eva_*', 'flexivit*', ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -52,7 +53,7 @@ if 'GITHUB_ACTIONS' in os.environ: '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', 'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge'] - NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*'] + NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'eva_giant*'] else: EXCLUDE_FILTERS = [] NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] @@ -156,6 +157,10 @@ def test_model_default_cfgs(model_name, batch_size): pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] + output_fmt = getattr(model, 'output_fmt', 'NCHW') + spatial_axis = get_spatial_dim(output_fmt) + assert len(spatial_axis) == 2 # TODO add 1D sequence support + feat_axis = get_channel_dim(output_fmt) if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \ not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]): @@ -165,13 +170,14 @@ def test_model_default_cfgs(model_name, batch_size): # test forward_features (always unpooled) outputs = model.forward_features(input_tensor) - assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config' + assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config' + assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config' # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) outputs = model.forward(input_tensor) assert len(outputs.shape) == 2 - assert outputs.shape[-1] == model.num_features + assert outputs.shape[1] == model.num_features # test model forward without pooling and classifier model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through @@ -179,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size): assert len(outputs.shape) == 4 if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): # mobilenetv3/ghostnet/vgg forward_features vs removed pooling 
differ due to location or lack of GAP - assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] + assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] if 'pruned' not in model_name: # FIXME better pruned model handling # test classifier + global pool deletion via __init__ @@ -187,7 +193,7 @@ def test_model_default_cfgs(model_name, batch_size): outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)): - assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] + assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] # check classifier name matches default_cfg if cfg.get('num_classes', None): @@ -330,176 +336,181 @@ def test_model_forward_features(model_name, batch_size): model = create_model(model_name, pretrained=False, features_only=True) model.eval() expected_channels = model.feature_info.channels() + expected_reduction = model.feature_info.reduction() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) if max(input_size) > MAX_FFEAT_SIZE: pytest.skip("Fixed input size model > limit.") + output_fmt = getattr(model, 'output_fmt', 'NCHW') + feat_axis = get_channel_dim(output_fmt) + spatial_axis = get_spatial_dim(output_fmt) + import math outputs = model(torch.randn((batch_size, *input_size))) assert len(expected_channels) == len(outputs) - for e, o in zip(expected_channels, outputs): - assert e == o.shape[1] + spatial_size = input_size[-2:] + for e, r, o in zip(expected_channels, expected_reduction, outputs): + assert e == o.shape[feat_axis] + assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1 + assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1 assert o.shape[0] == batch_size assert not torch.isnan(o).any() -if not _IS_MAC: - # MACOS test runners are really slow, only running tests below this point if not on a Darwin runner... +def _create_fx_model(model, train=False): + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + tracer_kwargs = dict( + leaf_modules=get_notrace_modules(), + autowrap_functions=get_notrace_functions(), + #enable_cpatching=True, + param_shapes_constant=True + ) + train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) - def _create_fx_model(model, train=False): - # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode - # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output - # node. 
Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names - tracer_kwargs = dict( - leaf_modules=list(_leaf_modules), - autowrap_functions=list(_autowrap_functions), - #enable_cpatching=True, - param_shapes_constant=True - ) - train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) + eval_return_nodes = [eval_nodes[-1]] + train_return_nodes = [train_nodes[-1]] + if train: + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] + train_return_nodes = [train_nodes[ix] for ix in output_node_indices] - eval_return_nodes = [eval_nodes[-1]] - train_return_nodes = [train_nodes[-1]] - if train: - tracer = NodePathTracer(**tracer_kwargs) - graph = tracer.trace(model) - graph_nodes = list(reversed(graph.nodes)) - output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] - graph_node_names = [n.name for n in graph_nodes] - output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] - train_return_nodes = [train_nodes[ix] for ix in output_node_indices] - - fx_model = create_feature_extractor( - model, - train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes, - tracer_kwargs=tracer_kwargs, - ) - return fx_model + fx_model = create_feature_extractor( + model, + train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes, + tracer_kwargs=tracer_kwargs, + ) + return fx_model - EXCLUDE_FX_FILTERS = ['vit_gi*'] - # not enough memory to run fx on more models than other tests - if 'GITHUB_ACTIONS' in os.environ: - EXCLUDE_FX_FILTERS += [ - 'beit_large*', - 'mixer_l*', - '*nfnet_f2*', - '*resnext101_32x32d', - 'resnetv2_152x2*', - 'resmlp_big*', - 'resnetrs270', - 'swin_large*', - 'vgg*', - 'vit_large*', - 'vit_base_patch8*', - 'xcit_large*', - ] +EXCLUDE_FX_FILTERS = ['vit_gi*'] +# not enough memory to run fx on more models than other tests +if 'GITHUB_ACTIONS' in os.environ: + EXCLUDE_FX_FILTERS += [ + 'beit_large*', + 'mixer_l*', + '*nfnet_f2*', + '*resnext101_32x32d', + 'resnetv2_152x2*', + 'resmlp_big*', + 'resnetrs270', + 'swin_large*', + 'vgg*', + 'vit_large*', + 'vit_base_patch8*', + 'xcit_large*', + ] + + +@pytest.mark.fxforward +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx(model_name, batch_size): + """ + Symbolically trace each model and run single forward pass through the resulting GraphModule + Also check that the output of a forward pass through the GraphModule is the same as that from the original Module + """ + if not has_fx_feature_extraction: + pytest.skip("Can't test FX. 
Torch >= 1.10 and Torchvision >= 0.11 are required.") + + model = create_model(model_name, pretrained=False) + model.eval() + + input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) + if max(input_size) > MAX_FWD_FX_SIZE: + pytest.skip("Fixed input size model > limit.") + with torch.no_grad(): + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + + model = _create_fx_model(model) + fx_outputs = tuple(model(inputs).values()) + if isinstance(fx_outputs, tuple): + fx_outputs = torch.cat(fx_outputs) + + assert torch.all(fx_outputs == outputs) + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.fxbackward +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [2]) +def test_model_backward_fx(model_name, batch_size): + """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") + + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE) + if max(input_size) > MAX_BWD_FX_SIZE: + pytest.skip("Fixed input size model > limit.") + + model = create_model(model_name, pretrained=False, num_classes=42) + model.train() + num_params = sum([x.numel() for x in model.parameters()]) + if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6: + pytest.skip("Skipping FX backward test on model with more than 100M params.") + + model = _create_fx_model(model, train=True) + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + outputs.mean().backward() + for n, x in model.named_parameters(): + assert x.grad is not None, f'No gradient for {n}' + num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) + + assert outputs.shape[-1] == 42 + assert num_params == num_grad, 'Some parameters are missing gradients' + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +if 'GITHUB_ACTIONS' not in os.environ: + # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process + + # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow + EXCLUDE_FX_JIT_FILTERS = [ + 'deit_*_distilled_patch16_224', + 'levit*', + 'pit_*_distilled_224', + ] + EXCLUDE_FX_FILTERS - @pytest.mark.fxforward @pytest.mark.timeout(120) - @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) + @pytest.mark.parametrize( + 'model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) - def test_model_forward_fx(model_name, batch_size): - """ - Symbolically trace each model and run single forward pass through the resulting GraphModule - Also check that the output of a forward pass through the GraphModule is the same as that from the original Module - """ + def test_model_forward_fx_torchscript(model_name, batch_size): + """Symbolically trace each model, script it, and run single forward pass""" if not has_fx_feature_extraction: pytest.skip("Can't test FX. 
Torch >= 1.10 and Torchvision >= 0.11 are required.") - model = create_model(model_name, pretrained=False) + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: + pytest.skip("Fixed input size model > limit.") + + with set_scriptable(True): + model = create_model(model_name, pretrained=False) model.eval() - input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) - if max(input_size) > MAX_FWD_FX_SIZE: - pytest.skip("Fixed input size model > limit.") + model = torch.jit.script(_create_fx_model(model)) with torch.no_grad(): - inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) if isinstance(outputs, tuple): outputs = torch.cat(outputs) - model = _create_fx_model(model) - fx_outputs = tuple(model(inputs).values()) - if isinstance(fx_outputs, tuple): - fx_outputs = torch.cat(fx_outputs) - - assert torch.all(fx_outputs == outputs) assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' - - - @pytest.mark.fxbackward - @pytest.mark.timeout(120) - @pytest.mark.parametrize('model_name', list_models( - exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) - @pytest.mark.parametrize('batch_size', [2]) - def test_model_backward_fx(model_name, batch_size): - """Symbolically trace each model and run single backward pass through the resulting GraphModule""" - if not has_fx_feature_extraction: - pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") - - input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE) - if max(input_size) > MAX_BWD_FX_SIZE: - pytest.skip("Fixed input size model > limit.") - - model = create_model(model_name, pretrained=False, num_classes=42) - model.train() - num_params = sum([x.numel() for x in model.parameters()]) - if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6: - pytest.skip("Skipping FX backward test on model with more than 100M params.") - - model = _create_fx_model(model, train=True) - outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) - outputs.mean().backward() - for n, x in model.named_parameters(): - assert x.grad is not None, f'No gradient for {n}' - num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) - - assert outputs.shape[-1] == 42 - assert num_params == num_grad, 'Some parameters are missing gradients' - assert not torch.isnan(outputs).any(), 'Output included NaNs' - - - if 'GITHUB_ACTIONS' not in os.environ: - # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process - - # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow - EXCLUDE_FX_JIT_FILTERS = [ - 'deit_*_distilled_patch16_224', - 'levit*', - 'pit_*_distilled_224', - ] + EXCLUDE_FX_FILTERS - - - @pytest.mark.timeout(120) - @pytest.mark.parametrize( - 'model_name', list_models( - exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) - @pytest.mark.parametrize('batch_size', [1]) - def test_model_forward_fx_torchscript(model_name, batch_size): - """Symbolically trace each model, script it, and run single forward pass""" - if not has_fx_feature_extraction: - pytest.skip("Can't test FX. 
Torch >= 1.10 and Torchvision >= 0.11 are required.") - - input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: - pytest.skip("Fixed input size model > limit.") - - with set_scriptable(True): - model = create_model(model_name, pretrained=False) - model.eval() - - model = torch.jit.script(_create_fx_model(model)) - with torch.no_grad(): - outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) - - assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 47b02892..d4eab660 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -20,6 +20,7 @@ from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d +from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to from .gather_excite import GatherExcite from .global_context import GlobalContext from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple diff --git a/timm/layers/adaptive_avgmax_pool.py b/timm/layers/adaptive_avgmax_pool.py index ebc6ada8..e18ea998 100644 --- a/timm/layers/adaptive_avgmax_pool.py +++ b/timm/layers/adaptive_avgmax_pool.py @@ -9,31 +9,37 @@ Both a functional and a nn.Module version of the pooling is provided. Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Optional, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F +from .format import get_spatial_dim, get_channel_dim + +_int_tuple_2_t = Union[int, Tuple[int, int]] + def adaptive_pool_feat_mult(pool_type='avg'): - if pool_type == 'catavgmax': + if pool_type.endswith('catavgmax'): return 2 else: return 1 -def adaptive_avgmax_pool2d(x, output_size=1): +def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1): x_avg = F.adaptive_avg_pool2d(x, output_size) x_max = F.adaptive_max_pool2d(x, output_size) return 0.5 * (x_avg + x_max) -def adaptive_catavgmax_pool2d(x, output_size=1): +def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1): x_avg = F.adaptive_avg_pool2d(x, output_size) x_max = F.adaptive_max_pool2d(x, output_size) return torch.cat((x_avg, x_max), 1) -def select_adaptive_pool2d(x, pool_type='avg', output_size=1): +def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1): """Selectable global pooling function with dynamic input kernel size """ if pool_type == 'avg': @@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1): return x -class FastAdaptiveAvgPool2d(nn.Module): - def __init__(self, flatten=False): - super(FastAdaptiveAvgPool2d, self).__init__() +class FastAdaptiveAvgPool(nn.Module): + def __init__(self, flatten: bool = False, input_fmt: F = 'NCHW'): + super(FastAdaptiveAvgPool, self).__init__() self.flatten = flatten + self.dim = get_spatial_dim(input_fmt) def forward(self, x): - return x.mean((2, 3), keepdim=not self.flatten) + return x.mean(self.dim, keepdim=not self.flatten) + + +class FastAdaptiveMaxPool(nn.Module): + def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'): + super(FastAdaptiveMaxPool, self).__init__() + self.flatten = flatten + self.dim 
= get_spatial_dim(input_fmt) + + def forward(self, x): + return x.amax(self.dim, keepdim=not self.flatten) + + +class FastAdaptiveAvgMaxPool(nn.Module): + def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'): + super(FastAdaptiveAvgMaxPool, self).__init__() + self.flatten = flatten + self.dim = get_spatial_dim(input_fmt) + + def forward(self, x): + x_avg = x.mean(self.dim, keepdim=not self.flatten) + x_max = x.amax(self.dim, keepdim=not self.flatten) + return 0.5 * x_avg + 0.5 * x_max + + +class FastAdaptiveCatAvgMaxPool(nn.Module): + def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'): + super(FastAdaptiveCatAvgMaxPool, self).__init__() + self.flatten = flatten + self.dim_reduce = get_spatial_dim(input_fmt) + if flatten: + self.dim_cat = 1 + else: + self.dim_cat = get_channel_dim(input_fmt) + + def forward(self, x): + x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten) + x_max = x.amax(self.dim_reduce, keepdim=not self.flatten) + return torch.cat((x_avg, x_max), self.dim_cat) class AdaptiveAvgMaxPool2d(nn.Module): - def __init__(self, output_size=1): + def __init__(self, output_size: _int_tuple_2_t = 1): super(AdaptiveAvgMaxPool2d, self).__init__() self.output_size = output_size @@ -68,7 +113,7 @@ class AdaptiveAvgMaxPool2d(nn.Module): class AdaptiveCatAvgMaxPool2d(nn.Module): - def __init__(self, output_size=1): + def __init__(self, output_size: _int_tuple_2_t = 1): super(AdaptiveCatAvgMaxPool2d, self).__init__() self.output_size = output_size @@ -79,26 +124,41 @@ class AdaptiveCatAvgMaxPool2d(nn.Module): class SelectAdaptivePool2d(nn.Module): """Selectable global pooling layer with dynamic input kernel size """ - def __init__(self, output_size=1, pool_type='fast', flatten=False): + def __init__( + self, + output_size: _int_tuple_2_t = 1, + pool_type: str = 'fast', + flatten: bool = False, + input_fmt: str = 'NCHW', + ): super(SelectAdaptivePool2d, self).__init__() + assert input_fmt in ('NCHW', 'NHWC') self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing - self.flatten = nn.Flatten(1) if flatten else nn.Identity() - if pool_type == '': + if not pool_type: self.pool = nn.Identity() # pass through - elif pool_type == 'fast': - assert output_size == 1 - self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + elif pool_type.startswith('fast') or input_fmt != 'NCHW': + assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.' 
+ if pool_type.endswith('avgmax'): + self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt) + elif pool_type.endswith('catavgmax'): + self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt) + elif pool_type.endswith('max'): + self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt) + else: + self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt) self.flatten = nn.Identity() - elif pool_type == 'avg': - self.pool = nn.AdaptiveAvgPool2d(output_size) - elif pool_type == 'avgmax': - self.pool = AdaptiveAvgMaxPool2d(output_size) - elif pool_type == 'catavgmax': - self.pool = AdaptiveCatAvgMaxPool2d(output_size) - elif pool_type == 'max': - self.pool = nn.AdaptiveMaxPool2d(output_size) else: - assert False, 'Invalid pool type: %s' % pool_type + assert input_fmt == 'NCHW' + if pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool2d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool2d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + self.pool = nn.AdaptiveAvgPool2d(output_size) + self.flatten = nn.Flatten(1) if flatten else nn.Identity() def is_identity(self): return not self.pool_type diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index d93d0ec7..a95a4dfe 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -15,13 +15,23 @@ from .create_act import get_act_layer from .create_norm import get_norm_layer -def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): +def _create_pool( + num_features: int, + num_classes: int, + pool_type: str = 'avg', + use_conv: bool = False, + input_fmt: Optional[str] = None, +): flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling if not pool_type: assert num_classes == 0 or use_conv,\ 'Pooling can only be disabled if classifier is also removed or conv classifier is used' flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) - global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + global_pool = SelectAdaptivePool2d( + pool_type=pool_type, + flatten=flatten_in_pool, + input_fmt=input_fmt, + ) num_pooled_features = num_features * global_pool.feat_mult() return global_pool, num_pooled_features @@ -36,9 +46,25 @@ def _create_fc(num_features, num_classes, use_conv=False): return fc -def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): - global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) - fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) +def create_classifier( + num_features: int, + num_classes: int, + pool_type: str = 'avg', + use_conv: bool = False, + input_fmt: str = 'NCHW', +): + global_pool, num_pooled_features = _create_pool( + num_features, + num_classes, + pool_type, + use_conv=use_conv, + input_fmt=input_fmt, + ) + fc = _create_fc( + num_pooled_features, + num_classes, + use_conv=use_conv, + ) return global_pool, fc @@ -52,6 +78,7 @@ class ClassifierHead(nn.Module): pool_type: str = 'avg', drop_rate: float = 0., use_conv: bool = False, + input_fmt: str = 'NCHW', ): """ Args: @@ -64,28 +91,43 @@ class ClassifierHead(nn.Module): self.drop_rate = drop_rate self.in_features = in_features self.use_conv = use_conv + self.input_fmt = input_fmt - self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv) - self.fc = _create_fc(num_pooled_features, num_classes, 
use_conv=use_conv) + self.global_pool, self.fc = create_classifier( + in_features, + num_classes, + pool_type, + use_conv=use_conv, + input_fmt=input_fmt, + ) self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() - def reset(self, num_classes, global_pool=None): - if global_pool is not None: - if global_pool != self.global_pool.pool_type: - self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv) - self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity() - num_pooled_features = self.in_features * self.global_pool.feat_mult() - self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv) + def reset(self, num_classes, pool_type=None): + if pool_type is not None and pool_type != self.global_pool.pool_type: + self.global_pool, self.fc = create_classifier( + self.in_features, + num_classes, + pool_type=pool_type, + use_conv=self.use_conv, + input_fmt=self.input_fmt, + ) + self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity() + else: + num_pooled_features = self.in_features * self.global_pool.feat_mult() + self.fc = _create_fc( + num_pooled_features, + num_classes, + use_conv=self.use_conv, + ) def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) if pre_logits: - return x.flatten(1) - else: - x = self.fc(x) return self.flatten(x) + x = self.fc(x) + return self.flatten(x) class NormMlpClassifierHead(nn.Module): diff --git a/timm/layers/format.py b/timm/layers/format.py new file mode 100644 index 00000000..7eadc1af --- /dev/null +++ b/timm/layers/format.py @@ -0,0 +1,58 @@ +from enum import Enum +from typing import Union + +import torch + + +class Format(str, Enum): + NCHW = 'NCHW' + NHWC = 'NHWC' + NCL = 'NCL' + NLC = 'NLC' + + +FormatT = Union[str, Format] + + +def get_spatial_dim(fmt: FormatT): + fmt = Format(fmt) + if fmt is Format.NLC: + dim = (1,) + elif fmt is Format.NCL: + dim = (2,) + elif fmt is Format.NHWC: + dim = (1, 2) + else: + dim = (2, 3) + return dim + + +def get_channel_dim(fmt: FormatT): + fmt = Format(fmt) + if fmt is Format.NHWC: + dim = 3 + elif fmt is Format.NLC: + dim = 2 + else: + dim = 1 + return dim + + +def nchw_to(x: torch.Tensor, fmt: Format): + if fmt == Format.NHWC: + x = x.permute(0, 2, 3, 1) + elif fmt == Format.NLC: + x = x.flatten(2).transpose(1, 2) + elif fmt == Format.NCL: + x = x.flatten(2) + return x + + +def nhwc_to(x: torch.Tensor, fmt: Format): + if fmt == Format.NCHW: + x = x.permute(0, 3, 1, 2) + elif fmt == Format.NLC: + x = x.flatten(1, 2) + elif fmt == Format.NCL: + x = x.flatten(1, 2).transpose(1, 2) + return x diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index 764519f2..05768674 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -9,12 +9,13 @@ Based on code in: Hacked together by / Copyright 2020 Ross Wightman """ import logging -from typing import List +from typing import List, Optional, Callable import torch from torch import nn as nn import torch.nn.functional as F +from .format import Format, nchw_to from .helpers import to_2tuple from .trace_utils import _assert @@ -24,15 +25,18 @@ _logger = logging.getLogger(__name__) class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ + output_fmt: Format + def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - norm_layer=None, - flatten=True, - bias=True, + img_size: int = 224, + 
patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, ): super().__init__() img_size = to_2tuple(img_size) @@ -41,7 +45,13 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] - self.flatten = flatten + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -52,7 +62,9 @@ class PatchEmbed(nn.Module): _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") x = self.proj(x) if self.flatten: - x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) x = self.norm(x) return x diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py index 7b843dc5..5cb3d0f4 100644 --- a/timm/layers/pos_embed_rel.py +++ b/timm/layers/pos_embed_rel.py @@ -15,26 +15,48 @@ from .weight_init import trunc_normal_ def gen_relative_position_index( q_size: Tuple[int, int], - k_size: Tuple[int, int] = None, - class_token: bool = False) -> torch.Tensor: + k_size: Optional[Tuple[int, int]] = None, + class_token: bool = False, +) -> torch.Tensor: # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window - q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww if k_size is None: - k_coords = q_coords - k_size = q_size + coords = torch.stack( + torch.meshgrid([ + torch.arange(q_size[0]), + torch.arange(q_size[1]) + ]) + ).flatten(1) # 2, Wh, Ww + relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3 else: - # different q vs k sizes is a WIP - k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1) - relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2 + # FIXME different q vs k sizes is a WIP, need to better offset the two grids? 
+ q_coords = torch.stack( + torch.meshgrid([ + torch.arange(q_size[0]), + torch.arange(q_size[1]) + ]) + ).flatten(1) # 2, Wh, Ww + k_coords = torch.stack( + torch.meshgrid([ + torch.arange(k_size[0]), + torch.arange(k_size[1]) + ]) + ).flatten(1) + relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0 + # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1 + # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1 + # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw + num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3 + _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) if class_token: # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias # NOTE not intended or tested with MLP log-coords - max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1])) - num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3 relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) relative_position_index[0, 0:] = num_relative_distance - 3 relative_position_index[0:, 0] = num_relative_distance - 2 @@ -59,7 +81,7 @@ class RelPosBias(nn.Module): self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) self.register_buffer( "relative_position_index", - gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0), + gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1), persistent=False, ) @@ -69,7 +91,7 @@ class RelPosBias(nn.Module): trunc_normal_(self.relative_position_bias_table, std=.02) def get_bias(self) -> torch.Tensor: - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # win_h * win_w, win_h * win_w, num_heads relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1) return relative_position_bias.unsqueeze(0).contiguous() @@ -148,7 +170,7 @@ class RelPosMlp(nn.Module): self.register_buffer( "relative_position_index", - gen_relative_position_index(window_size), + gen_relative_position_index(window_size).view(-1), persistent=False) # get relative_coords_table @@ -160,8 +182,7 @@ class RelPosMlp(nn.Module): def get_bias(self) -> torch.Tensor: relative_position_bias = self.mlp(self.rel_coords_log) if self.relative_position_index is not None: - relative_position_bias = relative_position_bias.view(-1, self.num_heads)[ - self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index] relative_position_bias = relative_position_bias.view(self.bias_shape) relative_position_bias = relative_position_bias.permute(2, 0, 1) relative_position_bias = self.bias_act(relative_position_bias) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 05cbbc81..7c2ab72b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -72,7 +72,8 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai from ._factory import create_model, parse_model_name, safe_model_name from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet from ._features_fx import 
FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ - register_notrace_module, register_notrace_function + register_notrace_module, is_notrace_module, get_notrace_modules, \ + register_notrace_function, is_notrace_function, get_notrace_functions from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 32a35304..f6a3f541 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -392,6 +392,9 @@ def build_model_with_cfg( # Wrap the model in a feature extraction module if enabled if features: feature_cls = FeatureListNet + output_fmt = getattr(model, 'output_fmt', None) + if output_fmt is not None: + feature_cfg.setdefault('output_fmt', output_fmt) if 'feature_cls' in feature_cfg: feature_cls = feature_cfg.pop('feature_cls') if isinstance(feature_cls, str): @@ -403,7 +406,7 @@ def build_model_with_cfg( else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) - model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg - model.default_cfg = model.pretrained_cfg # alias for backwards compat + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg + model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg) return model diff --git a/timm/models/_features.py b/timm/models/_features.py index 8e0b8984..04d48b14 100644 --- a/timm/models/_features.py +++ b/timm/models/_features.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint +from timm.layers import Format + __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] @@ -98,7 +100,7 @@ class FeatureHooks: self, hooks: Sequence[str], named_modules: dict, - out_map: Sequence[Union[int, str]] = None, + return_map: Sequence[Union[int, str]] = None, default_hook_type: str = 'forward', ): # setup feature hooks @@ -107,7 +109,7 @@ class FeatureHooks: for i, h in enumerate(hooks): hook_name = h['module'] m = modules[hook_name] - hook_id = out_map[i] if out_map else hook_name + hook_id = return_map[i] if return_map else hook_name hook_fn = partial(self._collect_output_hook, hook_id) hook_type = h.get('hook_type', default_hook_type) if hook_type == 'forward_pre': @@ -153,11 +155,11 @@ def _get_feature_info(net, out_indices): assert False, "Provided feature_info is not valid" -def _get_return_layers(feature_info, out_map): +def _get_return_layers(feature_info, return_map): module_names = feature_info.module_name() return_layers = {} for i, name in enumerate(module_names): - return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return_layers[name] = return_map[i] if return_map is not None else feature_info.out_indices[i] return return_layers @@ -180,7 +182,8 @@ class FeatureDictNet(nn.ModuleDict): self, model: nn.Module, out_indices: Tuple[int, ...] 
= (0, 1, 2, 3, 4), - out_map: Sequence[Union[int, str]] = None, + return_map: Sequence[Union[int, str]] = None, + output_fmt: str = 'NCHW', feature_concat: bool = False, flatten_sequential: bool = False, ): @@ -188,18 +191,19 @@ class FeatureDictNet(nn.ModuleDict): Args: model: Model from which to extract features. out_indices: Output indices of the model features to extract. - out_map: Return id mapping for each output index, otherwise str(index) is used. + return_map: Return id mapping for each output index, otherwise str(index) is used. feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting first element e.g. `x[0]` flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules) """ super(FeatureDictNet, self).__init__() self.feature_info = _get_feature_info(model, out_indices) + self.output_fmt = Format(output_fmt) self.concat = feature_concat self.grad_checkpointing = False self.return_layers = {} - return_layers = _get_return_layers(self.feature_info, out_map) + return_layers = _get_return_layers(self.feature_info, return_map) modules = _module_list(model, flatten_sequential=flatten_sequential) remaining = set(return_layers.keys()) layers = OrderedDict() @@ -253,6 +257,7 @@ class FeatureListNet(FeatureDictNet): self, model: nn.Module, out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + output_fmt: str = 'NCHW', feature_concat: bool = False, flatten_sequential: bool = False, ): @@ -264,9 +269,10 @@ class FeatureListNet(FeatureDictNet): first element e.g. `x[0]` flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules) """ - super(FeatureListNet, self).__init__( + super().__init__( model, out_indices=out_indices, + output_fmt=output_fmt, feature_concat=feature_concat, flatten_sequential=flatten_sequential, ) @@ -292,8 +298,9 @@ class FeatureHookNet(nn.ModuleDict): self, model: nn.Module, out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), - out_map: Sequence[Union[int, str]] = None, - out_as_dict: bool = False, + return_map: Sequence[Union[int, str]] = None, + return_dict: bool = False, + output_fmt: str = 'NCHW', no_rewrite: bool = False, flatten_sequential: bool = False, default_hook_type: str = 'forward', @@ -303,17 +310,18 @@ class FeatureHookNet(nn.ModuleDict): Args: model: Model from which to extract features. out_indices: Output indices of the model features to extract. - out_map: Return id mapping for each output index, otherwise str(index) is used. - out_as_dict: Output features as a dict. + return_map: Return id mapping for each output index, otherwise str(index) is used. + return_dict: Output features as a dict. no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed. flatten_sequential arg must also be False if this is set True. flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers. default_hook_type: The default hook type to use if not specified in model.feature_info. 
""" - super(FeatureHookNet, self).__init__() + super().__init__() assert not torch.jit.is_scripting() self.feature_info = _get_feature_info(model, out_indices) - self.out_as_dict = out_as_dict + self.return_dict = return_dict + self.output_fmt = Format(output_fmt) self.grad_checkpointing = False layers = OrderedDict() @@ -340,7 +348,7 @@ class FeatureHookNet(nn.ModuleDict): break assert not remaining, f'Return layers ({remaining}) are not present in model' self.update(layers) - self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + self.hooks = FeatureHooks(hooks, model.named_modules(), return_map=return_map) def set_grad_checkpointing(self, enable: bool = True): self.grad_checkpointing = enable @@ -356,4 +364,4 @@ class FeatureHookNet(nn.ModuleDict): else: x = module(x) out = self.hooks.get_output(x.device) - return out if self.out_as_dict else list(out.values()) + return out if self.return_dict else list(out.values()) diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 10670a1d..c894d66c 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -19,6 +19,11 @@ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSa from timm.layers.non_local_attn import BilinearAttnTransform from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame +__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules', + 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions', + 'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet'] + + # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here # BUT modules from timm.models should use the registration mechanism below _leaf_modules = { @@ -35,10 +40,6 @@ except ImportError: pass -__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor', - 'FeatureGraphNet', 'GraphExtractNet'] - - def register_notrace_module(module: Type[nn.Module]): """ Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 
@@ -47,6 +48,14 @@ def register_notrace_module(module: Type[nn.Module]): return module +def is_notrace_module(module: Type[nn.Module]): + return module in _leaf_modules + + +def get_notrace_modules(): + return list(_leaf_modules) + + # Functions we want to autowrap (treat them as leaves) _autowrap_functions = set() @@ -59,6 +68,14 @@ def register_notrace_function(func: Callable): return func +def is_notrace_function(func: Callable): + return func in _autowrap_functions + + +def get_notrace_functions(): + return list(_autowrap_functions) + + def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' return _create_feature_extractor( diff --git a/timm/models/convnext.py b/timm/models/convnext.py index b2e8fce7..bad5a44b 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -406,7 +406,7 @@ class ConvNeXt(nn.Module): return self.head.fc def reset_classifier(self, num_classes=0, global_pool=None): - self.head.reset(num_classes, global_pool=global_pool) + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) @@ -415,7 +415,7 @@ class ConvNeXt(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/dla.py b/timm/models/dla.py index 204fcb4b..5231225e 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -359,10 +359,9 @@ class DLA(nn.Module): if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) if pre_logits: - return x.flatten(1) - else: - x = self.fc(x) return self.flatten(x) + x = self.fc(x) + return self.flatten(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 29a7a7e8..4ee42778 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -298,10 +298,9 @@ class DPN(nn.Module): if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) if pre_logits: - return x.flatten(1) - else: - x = self.classifier(x) return self.flatten(x) + x = self.classifier(x) + return self.flatten(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 8178cfc3..5b201db2 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -18,6 +18,7 @@ This impl is/has: # Written by Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- from functools import partial +from typing import Callable, Optional, Tuple import torch import torch.nn as nn @@ -35,15 +36,15 @@ __all__ = ['FocalNet'] class FocalModulation(nn.Module): def __init__( self, - dim, + dim: int, focal_window, - focal_level, - focal_factor=2, - bias=True, - use_post_norm=False, - normalize_modulator=False, - proj_drop=0., - norm_layer=LayerNorm2d, + focal_level: int, + focal_factor: int = 2, + bias: bool = True, + use_post_norm: bool = False, + normalize_modulator: bool = False, + proj_drop: float = 0., + norm_layer: Callable = LayerNorm2d, ): super().__init__() @@ -118,36 +119,38 @@ class LayerScale2d(nn.Module): class FocalNetBlock(nn.Module): - r""" Focal Modulation Network Block. - - Args: - dim (int): Number of input channels. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
- proj_drop (float, optional): Dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - focal_level (int): Number of focal levels. - focal_window (int): Focal window size at first focal level - layerscale_value (float): Initial layerscale value - use_post_norm (bool): Whether to use layernorm after modulation + """ Focal Modulation Network Block. """ def __init__( self, - dim, - mlp_ratio=4., - focal_level=1, - focal_window=3, - use_post_norm=False, - use_post_norm_in_modulation=False, - normalize_modulator=False, - layerscale_value=1e-4, - proj_drop=0., - drop_path=0., - act_layer=nn.GELU, - norm_layer=LayerNorm2d, + dim: int, + mlp_ratio: float = 4., + focal_level: int = 1, + focal_window: int = 3, + use_post_norm: bool = False, + use_post_norm_in_modulation: bool = False, + normalize_modulator: bool = False, + layerscale_value: float = 1e-4, + proj_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm2d, ): + """ + Args: + dim: Number of input channels. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + focal_level: Number of focal levels. + focal_window: Focal window size at first focal level. + use_post_norm: Whether to use layer norm after modulation. + use_post_norm_in_modulation: Whether to use layer norm in modulation. + layerscale_value: Initial layerscale value. + proj_drop: Dropout rate. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + """ super().__init__() self.dim = dim self.mlp_ratio = mlp_ratio @@ -197,42 +200,45 @@ class FocalNetBlock(nn.Module): return x -class BasicLayer(nn.Module): +class FocalNetStage(nn.Module): """ A basic Focal Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - depth (int): Number of blocks. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - drop (float, optional): Dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (bool): Downsample layer at start of the layer. Default: True - focal_level (int): Number of focal levels - focal_window (int): Focal window size at first focal level - layerscale_value (float): Initial layerscale value - use_post_norm (bool): Whether to use layer norm after modulation """ def __init__( self, - dim, - out_dim, - depth, - mlp_ratio=4., - downsample=True, - focal_level=1, - focal_window=1, - use_overlap_down=False, - use_post_norm=False, - use_post_norm_in_modulation=False, - normalize_modulator=False, - layerscale_value=1e-4, - proj_drop=0., - drop_path=0., - norm_layer=LayerNorm2d, + dim: int, + out_dim: int, + depth: int, + mlp_ratio: float = 4., + downsample: bool = True, + focal_level: int = 1, + focal_window: int = 1, + use_overlap_down: bool = False, + use_post_norm: bool = False, + use_post_norm_in_modulation: bool = False, + normalize_modulator: bool = False, + layerscale_value: float = 1e-4, + proj_drop: float = 0., + drop_path: float = 0., + norm_layer: Callable = LayerNorm2d, ): - + """ + Args: + dim: Number of input channels. + out_dim: Number of output channels. + depth: Number of blocks. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + downsample: Downsample layer at start of the layer. 
+ focal_level: Number of focal levels + focal_window: Focal window size at first focal level + use_overlap_down: User overlapped convolution in downsample layer. + use_post_norm: Whether to use layer norm after modulation. + use_post_norm_in_modulation: Whether to use layer norm in modulation. + layerscale_value: Initial layerscale value + proj_drop: Dropout rate for projections. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. + """ super().__init__() self.dim = dim self.depth = depth @@ -281,22 +287,24 @@ class BasicLayer(nn.Module): class Downsample(nn.Module): - r""" - Args: - in_chs (int): Number of input image channels - out_chs (int): Number of linear projection output channels - stride (int): Downsample stride. Default: 4. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ def __init__( self, - in_chs, - out_chs, - stride=4, - overlap=False, - norm_layer=None, + in_chs: int, + out_chs: int, + stride: int = 4, + overlap: bool = False, + norm_layer: Optional[Callable] = None, ): + """ + + Args: + in_chs: Number of input image channels. + out_chs: Number of linear projection output channels. + stride: Downsample stride. + overlap: Use overlapping convolutions if True. + norm_layer: Normalization layer. + """ super().__init__() self.stride = stride padding = 0 @@ -317,49 +325,47 @@ class Downsample(nn.Module): class FocalNet(nn.Module): - r""" Focal Modulation Networks (FocalNets) - - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Focal Transformer layer. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. - Default: [1, 1, 1, 1] - focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1] - use_overlap_down (bool): Whether to use convolutional embedding. - use_post_norm (bool): Whether to use layernorm after modulation (it helps stablize training of large models) - layerscale_value (float): Value for layer scale. Default: 1e-4 - drop_rate (float): Dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - + """" Focal Modulation Networks (FocalNets) """ def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - embed_dim=96, - depths=(2, 2, 6, 2), - mlp_ratio=4., - focal_levels=(2, 2, 2, 2), - focal_windows=(3, 3, 3, 3), - use_overlap_down=False, - use_post_norm=False, - use_post_norm_in_modulation=False, - normalize_modulator=False, - head_hidden_size=None, - head_init_scale=1.0, - layerscale_value=None, - drop_rate=0., - proj_drop_rate=0., - drop_path_rate=0.1, - norm_layer=partial(LayerNorm2d, eps=1e-5), - **kwargs, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + mlp_ratio: float = 4., + focal_levels: Tuple[int, ...] = (2, 2, 2, 2), + focal_windows: Tuple[int, ...] 
= (3, 3, 3, 3), + use_overlap_down: bool = False, + use_post_norm: bool = False, + use_post_norm_in_modulation: bool = False, + normalize_modulator: bool = False, + head_hidden_size: Optional[int] = None, + head_init_scale: float = 1.0, + layerscale_value: Optional[float] = None, + drop_rate: bool = 0., + proj_drop_rate: bool = 0., + drop_path_rate: bool = 0.1, + norm_layer: Callable = partial(LayerNorm2d, eps=1e-5), ): + """ + Args: + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + embed_dim: Patch embedding dimension. + depths: Depth of each Focal Transformer layer. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + focal_levels: How many focal levels at all stages. Note that this excludes the finest-grain level. + focal_windows: The focal window size at all stages. + use_overlap_down: Whether to use convolutional embedding. + use_post_norm: Whether to use layernorm after modulation (it helps stablize training of large models) + layerscale_value: Value for layer scale. + drop_rate: Dropout rate. + drop_path_rate: Stochastic depth rate. + norm_layer: Normalization layer. + """ super().__init__() self.num_layers = len(depths) @@ -382,7 +388,7 @@ class FocalNet(nn.Module): layers = [] for i_layer in range(self.num_layers): out_dim = embed_dim[i_layer] - layer = BasicLayer( + layer = FocalNetStage( dim=in_dim, out_dim=out_dim, depth=depths[i_layer], @@ -438,10 +444,10 @@ class FocalNet(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.classifier.fc + return self.head.fc def reset_classifier(self, num_classes, global_pool=None): - self.classifier.reset(num_classes, global_pool=global_pool) + self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): x = self.stem(x) @@ -475,7 +481,7 @@ def _init_weights(module, name=None, head_init_scale=1.0): def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.proj', 'classifier': 'head.fc', @@ -498,19 +504,19 @@ default_cfgs = { url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'), "focalnet_large_fl3": _cfg( url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth', - input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842), "focalnet_large_fl4": _cfg( url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth', - input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842), "focalnet_xlarge_fl3": _cfg( url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth', - input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842), "focalnet_xlarge_fl4": _cfg( url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth', - input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842), + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842), "focalnet_huge_fl3": _cfg( 
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth', - num_classes=0), + num_classes=21842), "focalnet_huge_fl4": _cfg( url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth', num_classes=0), @@ -533,7 +539,7 @@ def checkpoint_filter_fn(state_dict, model: FocalNet): k = re.sub(r'norm([0-9])', r'norm\1_post', k) k = k.replace('ln.', 'norm.') k = k.replace('head', 'head.fc') - if dest_dict[k].shape != v.shape: + if k in dest_dict and dest_dict[k].numel() == v.numel() and dest_dict[k].shape != v.shape: v = v.reshape(dest_dict[k].shape) out_dict[k] = v return out_dict diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 2423a954..32d2e703 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -29,12 +29,11 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ - get_attn, get_act_layer, get_norm_layer, _assert + get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply from ._registry import register_model -from .vision_transformer_relpos import RelPosBias # FIXME move to common location __all__ = ['GlobalContextVit'] @@ -222,7 +221,7 @@ class WindowAttentionGlobal(nn.Module): q, k, v = qkv.unbind(0) q = q * self.scale - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0 attn = self.rel_pos(attn) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 5943781f..d50a1ae0 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -145,13 +145,12 @@ class MobileNetV3(nn.Module): x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) + x = self.flatten(x) if pre_logits: - return x.flatten(1) - else: - x = self.flatten(x) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) - return self.classifier(x) + return x + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index bbc97036..495a4c94 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -17,86 +17,24 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # -------------------------------------------------------- import logging import math -from typing import Optional +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply from ._registry import register_model -from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit +from .vision_transformer import get_init_weights_vit __all__ = ['SwinTransformer'] # 
model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'swin_base_patch4_window12_384': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_base_patch4_window7_224': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', - ), - - 'swin_large_patch4_window12_384': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_large_patch4_window7_224': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', - ), - - 'swin_small_patch4_window7_224': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', - ), - - 'swin_tiny_patch4_window7_224': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', - ), - - 'swin_base_patch4_window12_384_in22k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', - input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), - - 'swin_base_patch4_window7_224_in22k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', - num_classes=21841), - - 'swin_large_patch4_window12_384_in22k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', - input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), - - 'swin_large_patch4_window7_224_in22k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', - num_classes=21841), - - 'swin_s3_tiny_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth' - ), - 'swin_s3_small_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth' - ), - 'swin_s3_base_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth' - ) -} +_int_or_tuple_2_t = Union[int, Tuple[int, int]] def window_partition(x, window_size: int): @@ -132,7 +70,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): return x -def get_relative_position_index(win_h, win_w): +def get_relative_position_index(win_h: int, win_w: int): # get pair-wise relative position index for each token inside the window coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww @@ -145,21 +83,30 @@ def get_relative_position_index(win_h, win_w): class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. 
- num_heads (int): Number of attention heads. - head_dim (int): Number of channels per head (dim // num_heads if not set) - window_size (tuple[int]): The height and width of the window. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports shifted and non-shifted windows. """ - def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.): - + def __init__( + self, + dim: int, + num_heads: int, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + ): + """ + Args: + dim: Number of input channels. + num_heads: Number of attention heads. + head_dim: Number of channels per head (dim // num_heads if not set) + window_size: The height and width of the window. + qkv_bias: If True, add a learnable bias to query, key, value. + attn_drop: Dropout ratio of attention weight. + proj_drop: Dropout ratio of output. + """ super().__init__() self.dim = dim self.window_size = to_2tuple(window_size) # Wh, Ww @@ -198,7 +145,7 @@ class WindowAttention(nn.Module): """ B_, N, C = x.shape qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv.unbind(0) q = q * self.scale attn = (q @ k.transpose(-2, -1)) @@ -206,7 +153,7 @@ class WindowAttention(nn.Module): if mask is not None: num_win = mask.shape[0] - attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: @@ -221,28 +168,41 @@ class WindowAttention(nn.Module): class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - window_size (int): Window size. - num_heads (int): Number of attention heads. - head_dim (int): Enforce the number of channels per head - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ Swin Transformer Block. 
""" def __init__( - self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): + self, + dim: int, + input_resolution: _int_or_tuple_2_t, + num_heads: int = 4, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + shift_size: int = 0, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + window_size: Window size. + num_heads: Number of attention heads. + head_dim: Enforce the number of channels per head + shift_size: Shift size for SW-MSA. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + """ super().__init__() self.dim = dim self.input_resolution = input_resolution @@ -257,12 +217,23 @@ class SwinTransformerBlock(nn.Module): self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size), - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + dim, + num_heads=num_heads, + head_dim=head_dim, + window_size=to_2tuple(self.window_size), + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) if self.shift_size > 0: # calculate attention mask for SW-MSA @@ -285,17 +256,15 @@ class SwinTransformerBlock(nn.Module): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - self.register_buffer("attn_mask", attn_mask) def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - _assert(L == H * W, "input feature has wrong size") + B, H, W, C = x.shape + _assert(H == self.input_resolution[0], "input feature has wrong size") + _assert(W == self.input_resolution[1], "input feature has wrong size") shortcut = x x = self.norm1(x) - x = x.view(B, H, W, C) # cyclic shift if self.shift_size > 0: @@ -319,176 +288,227 @@ class SwinTransformerBlock(nn.Module): x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) else: x = shifted_x - x = x.view(B, H * W, C) # FFN x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.reshape(B, -1, C) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.reshape(B, H, W, C) return x class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ Patch Merging Layer. """ - def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm): + def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm): + """ + Args: + dim: Number of input channels. 
+ out_dim: Number of output channels (or 2 * dim if None). + norm_layer: Normalization layer. + """ super().__init__() - self.input_resolution = input_resolution self.dim = dim self.out_dim = out_dim or 2 * dim self.norm = norm_layer(4 * dim) self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - _assert(L == H * W, "input feature has wrong size") - _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.") - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - + B, H, W, C = x.shape + _assert(H % 2 == 0, f"x height ({H}) is not even.") + _assert(W % 2 == 0, f"x width ({W}) is not even.") + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) x = self.norm(x) x = self.reduction(x) - return x -class BasicLayer(nn.Module): +class SwinTransformerStage(nn.Module): """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - head_dim (int): Channels per head (dim // num_heads if not set) - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None """ def __init__( - self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None, - window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None): - + self, + dim: int, + out_dim: int, + input_resolution: Tuple[int, int], + depth: int, + downsample: bool = True, + num_heads: int = 4, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0., + norm_layer: Callable = nn.LayerNorm, + output_nchw: bool = False, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + depth: Number of blocks. + num_heads: Number of attention heads. + head_dim: Channels per head (dim // num_heads if not set). + window_size: Local window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. + downsample: Use a patch merging downsample layer at the start of the stage if True.
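Illustrative sketch (not part of the patch, assumes only torch): the NHWC patch-merge reshape used by the rewritten PatchMerging.forward above folds each 2x2 spatial neighborhood into the channel dimension before the norm and linear reduction.

import torch

B, H, W, C = 2, 8, 8, 96                     # any even H, W
x = torch.randn(B, H, W, C)                  # NHWC features entering PatchMerging
merged = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
assert merged.shape == (B, H // 2, W // 2, 4 * C)   # 2x2 window folded into channels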
+ """ super().__init__() self.dim = dim self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution self.depth = depth + self.use_nchw = output_nchw self.grad_checkpointing = False + # patch merging layer + if downsample: + self.downsample = PatchMerging( + dim=dim, + out_dim=out_dim, + norm_layer=norm_layer, + ) + else: + assert dim == out_dim + self.downsample = nn.Identity() + # build blocks self.blocks = nn.Sequential(*[ SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim, - window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + dim=out_dim, + input_resolution=self.output_resolution, + num_heads=num_heads, + head_dim=head_dim, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) for i in range(depth)]) - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer) - else: - self.downsample = None - def forward(self, x): + if self.use_nchw: + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) - if self.downsample is not None: - x = self.downsample(x) + + if self.use_nchw: + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW return x class SwinTransformer(nn.Module): - r""" Swin Transformer - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 + """ Swin Transformer - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - head_dim (int, tuple(int)): - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. 
Default: True + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 """ def __init__( - self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None, - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs): + self, + img_size: _int_or_tuple_2_t = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + norm_layer: Union[str, Callable] = nn.LayerNorm, + weight_init: str = '', + output_fmt: str = 'NHWC', + **kwargs, + ): + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer layer. + num_heads: Number of attention heads in different layers. + head_dim: Dimension of self-attention heads. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop_rate: Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. 
+ """ super().__init__() assert global_pool in ('', 'avg') + assert output_fmt in ('NCHW', 'NHWC') self.num_classes = num_classes self.global_pool = global_pool + self.output_fmt = output_fmt + self.num_layers = len(depths) self.embed_dim = embed_dim self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.output_nchw = self.output_fmt == 'NCHW' # bool flag for fwd + self.feature_info = [] + + if not isinstance(embed_dim, (tuple, list)): + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if patch_norm else None) - num_patches = self.patch_embed.num_patches + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim[0], + norm_layer=norm_layer, + output_fmt='NHWC', + ) self.patch_grid = self.patch_embed.grid_size - # absolute position embedding - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None - self.pos_drop = nn.Dropout(p=drop_rate) - # build layers - if not isinstance(embed_dim, (tuple, list)): - embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - embed_out_dim = embed_dim[1:] + [None] head_dim = to_ntuple(self.num_layers)(head_dim) window_size = to_ntuple(self.num_layers)(window_size) mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] layers = [] + in_dim = embed_dim[0] + scale = 1 for i in range(self.num_layers): - layers += [BasicLayer( - dim=embed_dim[i], - out_dim=embed_out_dim[i], - input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)), + out_dim = embed_dim[i] + layers += [SwinTransformerStage( + dim=in_dim, + out_dim=out_dim, + input_resolution=( + self.patch_grid[0] // scale, + self.patch_grid[1] // scale + ), depth=depths[i], + downsample=i > 0, num_heads=num_heads[i], head_dim=head_dim[i], window_size=window_size[i], @@ -496,29 +516,36 @@ class SwinTransformer(nn.Module): qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + drop_path=dpr[i], norm_layer=norm_layer, - downsample=PatchMerging if (i < self.num_layers - 1) else None + output_nchw=self.output_nchw, )] + in_dim = out_dim + if i > 0: + scale *= 2 + self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + input_fmt=self.output_fmt, + ) if weight_init != 'skip': self.init_weights(weight_init) @torch.jit.ignore def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') - if self.absolute_pos_embed is not None: - trunc_normal_(self.absolute_pos_embed, std=.02) head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
named_apply(get_init_weights_vit(mode, head_bias=head_bias), self) @torch.jit.ignore def no_weight_decay(self): - nwd = {'absolute_pos_embed'} + nwd = set() for n, _ in self.named_parameters(): if 'relative_position_bias_table' in n: nwd.add(n) @@ -527,7 +554,7 @@ class SwinTransformer(nn.Module): @torch.jit.ignore def group_matcher(self, coarse=False): return dict( - stem=r'^absolute_pos_embed|patch_embed', # stem and embed + stem=r'^patch_embed', # stem and embed blocks=r'^layers\.(\d+)' if coarse else [ (r'^layers\.(\d+).downsample', (0,)), (r'^layers\.(\d+)\.\w+\.(\d+)', None), @@ -542,28 +569,26 @@ class SwinTransformer(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.head + return self.head.fc def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if global_pool is not None: - assert global_pool in ('', 'avg') - self.global_pool = global_pool - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head.reset(num_classes, pool_type=global_pool) def forward_features(self, x): x = self.patch_embed(x) - if self.absolute_pos_embed is not None: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) + if self.output_nchw: + # patch embed always outputs NHWC, stage layers expect NCHW input + x = x.permute(0, 3, 1, 2) x = self.layers(x) - x = self.norm(x) # B L C + if self.output_nchw: + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + else: + x = self.norm(x) return x def forward_head(self, x, pre_logits: bool = False): - if self.global_pool == 'avg': - x = x.mean(dim=1) - return x if pre_logits else self.head(x) + return self.head(x, pre_logits=True) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) @@ -571,15 +596,102 @@ class SwinTransformer(nn.Module): return x +def checkpoint_filter_fn( + state_dict, + model, + adapt_layer_scale=False, + interpolation='bicubic', + antialias=True, +): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re + out_dict = {} + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + for k, v in state_dict.items(): + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + out_dict[k] = v + return out_dict + + def _create_swin_transformer(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + model = build_model_with_cfg( SwinTransformer, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + 
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_s3_tiny_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth' + ), + 'swin_s3_small_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth' + ), + 'swin_s3_base_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth' + ) +} + + @register_model def swin_base_patch4_window12_384(pretrained=False, **kwargs): """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index ffdb85e0..56b0d8cc 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # Written by Ze Liu # -------------------------------------------------------- import math -from typing import Tuple, Optional +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -21,76 +21,14 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert +from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._registry import register_model __all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': 
IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'swinv2_tiny_window8_256': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth', - input_size=(3, 256, 256) - ), - 'swinv2_tiny_window16_256': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth', - input_size=(3, 256, 256) - ), - 'swinv2_small_window8_256': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth', - input_size=(3, 256, 256) - ), - 'swinv2_small_window16_256': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth', - input_size=(3, 256, 256) - ), - 'swinv2_base_window8_256': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth', - input_size=(3, 256, 256) - ), - 'swinv2_base_window16_256': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth', - input_size=(3, 256, 256) - ), - - 'swinv2_base_window12_192_22k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', - num_classes=21841, input_size=(3, 192, 192) - ), - 'swinv2_base_window12to16_192to256_22kft1k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth', - input_size=(3, 256, 256) - ), - 'swinv2_base_window12to24_192to384_22kft1k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth', - input_size=(3, 384, 384), crop_pct=1.0, - ), - 'swinv2_large_window12_192_22k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', - num_classes=21841, input_size=(3, 192, 192) - ), - 'swinv2_large_window12to16_192to256_22kft1k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth', - input_size=(3, 256, 256) - ), - 'swinv2_large_window12to24_192to384_22kft1k': _cfg( - url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth', - input_size=(3, 384, 384), crop_pct=1.0, - ), -} +_int_or_tuple_2_t = Union[int, Tuple[int, int]] def window_partition(x, window_size: Tuple[int, int]): @@ -141,9 +79,15 @@ class WindowAttention(nn.Module): """ def __init__( - self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=[0, 0]): - + self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + pretrained_window_size=[0, 0], + ): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww @@ -231,8 +175,8 @@ class WindowAttention(nn.Module): attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + num_win = mask.shape[0] + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) else: @@ -246,29 +190,42 @@ class WindowAttention(nn.Module): return x -class 
SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pretraining. +class SwinTransformerV2Block(nn.Module): + """ Swin Transformer Block. """ def __init__( - self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + num_heads: Number of attention heads. + window_size: Window size. + shift_size: Shift size for SW-MSA. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + pretrained_window_size: Window size in pretraining. + """ super().__init__() self.dim = dim self.input_resolution = to_2tuple(input_resolution) @@ -280,13 +237,23 @@ class SwinTransformerBlock(nn.Module): self.mlp_ratio = mlp_ratio self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -322,10 +289,7 @@ class SwinTransformerBlock(nn.Module): return tuple(window_size), tuple(shift_size) def _attn(self, x): - H, W = self.input_resolution - B, L, C = x.shape - _assert(L == H * W, "input feature has wrong size") - x = x.view(B, H, W, C) + B, H, W, C = x.shape # cyclic shift has_shift = any(self.shift_size) @@ -350,113 +314,130 @@ class SwinTransformerBlock(nn.Module): x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) else: x = shifted_x - x = x.view(B, H * W, C) return x def forward(self, x): + B, H, W, C = x.shape x = x + self.drop_path1(self.norm1(self._attn(x))) + x = x.reshape(B, -1, C) x = x + self.drop_path2(self.norm2(self.mlp(x))) + x = x.reshape(B, H, W, C) return x class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ Patch Merging Layer. """ - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm): + """ + Args: + dim (int): Number of input channels. + out_dim (int): Number of output channels (or 2 * dim if None) + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ super().__init__() - self.input_resolution = input_resolution self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) + self.out_dim = out_dim or 2 * dim + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) + self.norm = norm_layer(self.out_dim) def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - _assert(L == H * W, "input feature has wrong size") - _assert(H % 2 == 0, f"x size ({H}*{W}) are not even.") - _assert(W % 2 == 0, f"x size ({H}*{W}) are not even.") - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - + B, H, W, C = x.shape + _assert(H % 2 == 0, f"x height ({H}) is not even.") + _assert(W % 2 == 0, f"x width ({W}) is not even.") + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) x = self.reduction(x) x = self.norm(x) - return x -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - pretrained_window_size (int): Local window size in pre-training. +class SwinTransformerV2Stage(nn.Module): + """ A Swin Transformer V2 Stage. 
""" def __init__( - self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0): - + self, + dim, + out_dim, + input_resolution, + depth, + num_heads, + window_size, + downsample=False, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + output_nchw=False, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + depth: Number of blocks. + num_heads: Number of attention heads. + window_size: Local window size. + downsample: Use downsample layer at start of the block. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. + pretrained_window_size: Local window size in pretraining. + output_nchw: Output tensors on NCHW format instead of NHWC. + """ super().__init__() self.dim = dim self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution self.depth = depth + self.output_nchw = output_nchw self.grad_checkpointing = False + # patch merging / downsample layer + if downsample: + self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer) + else: + assert dim == out_dim + self.downsample = nn.Identity() + # build blocks self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, + SwinTransformerV2Block( + dim=out_dim, + input_resolution=self.output_resolution, + num_heads=num_heads, + window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, + drop=drop, + attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) + pretrained_window_size=pretrained_window_size, + ) for i in range(depth)]) - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = nn.Identity() - def forward(self, x): + if self.output_nchw: + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + + x = self.downsample(x) + for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint.checkpoint(blk, x) else: x = blk(x) - x = self.downsample(x) + + if self.output_nchw: + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW return x def _init_respostnorm(self): @@ -468,88 +449,117 @@ class BasicLayer(nn.Module): class SwinTransformerV2(nn.Module): - r""" Swin Transformer V2 - A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - - https://arxiv.org/abs/2111.09883 - Args: - img_size (int | tuple(int)): Input image size. Default 224 - patch_size (int | tuple(int)): Patch size. Default: 4 - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. 
Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. + """ Swin Transformer V2 + + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 """ def __init__( - self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - pretrained_window_sizes=(0, 0, 0, 0), **kwargs): + self, + img_size: _int_or_tuple_2_t = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + norm_layer: Callable = nn.LayerNorm, + pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0), + output_fmt: str = 'NHWC', + **kwargs, + ): + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer stage (layer). + num_heads: Number of attention heads in different layers. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop_rate: Dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + norm_layer: Normalization layer. + global_pool: Pooling type for classification head ('avg' or ''). + pretrained_window_sizes: Pretrained window sizes of each layer. + output_fmt: Output feature format, 'NHWC' (default) or 'NCHW'.
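Illustrative sketch (an assumption, not part of the patch: it presumes extra kwargs such as output_fmt are forwarded by timm.create_model to the model constructor as usual): with the NHWC stage refactor, forward_features returns NHWC features by default and NCHW when output_fmt='NCHW'.

import timm
import torch

x = torch.randn(1, 3, 256, 256)
nhwc = timm.create_model('swinv2_tiny_window8_256', pretrained=False)                      # output_fmt='NHWC'
nchw = timm.create_model('swinv2_tiny_window8_256', pretrained=False, output_fmt='NCHW')
print(nhwc.forward_features(x).shape)  # expected torch.Size([1, 8, 8, 768])
print(nchw.forward_features(x).shape)  # expected torch.Size([1, 768, 8, 8])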
+ """ super().__init__() self.num_classes = num_classes assert global_pool in ('', 'avg') + assert output_fmt in ('NCHW', 'NHWC') self.global_pool = global_pool + self.output_fmt = output_fmt self.num_layers = len(depths) self.embed_dim = embed_dim - self.patch_norm = patch_norm self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.output_nchw = self.output_fmt == 'NCHW' + self.feature_info = [] + + if not isinstance(embed_dim, (tuple, list)): + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] # split image into non-overlapping patches self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim[0], + norm_layer=norm_layer, + output_fmt='NHWC', + ) - # absolute position embedding - if ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - else: - self.absolute_pos_embed = None - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + layers = [] + in_dim = embed_dim[0] + scale = 1 + for i in range(self.num_layers): + out_dim = embed_dim[i] + layers += [SwinTransformerV2Stage( + dim=in_dim, + out_dim=out_dim, input_resolution=( - self.patch_embed.grid_size[0] // (2 ** i_layer), - self.patch_embed.grid_size[1] // (2 ** i_layer)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], + self.patch_embed.grid_size[0] // scale, + self.patch_embed.grid_size[1] // scale), + depth=depths[i], + downsample=i > 0, + num_heads=num_heads[i], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + drop_path=dpr[i], norm_layer=norm_layer, - downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, - pretrained_window_size=pretrained_window_sizes[i_layer] - ) - self.layers.append(layer) + pretrained_window_size=pretrained_window_sizes[i], + output_nchw=self.output_nchw, + )] + in_dim = out_dim + if i > 0: + scale *= 2 + self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] + self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + input_fmt=self.output_fmt, + ) self.apply(self._init_weights) for bly in self.layers: @@ -563,7 +573,7 @@ class SwinTransformerV2(nn.Module): @torch.jit.ignore def no_weight_decay(self): - nod = {'absolute_pos_embed'} + nod = set() for n, m in self.named_modules(): if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]): nod.add(n) @@ -587,31 +597,26 @@ class SwinTransformerV2(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.head + return self.head.fc def reset_classifier(self, num_classes, global_pool=None): self.num_classes = 
num_classes
-        if global_pool is not None:
-            assert global_pool in ('', 'avg')
-            self.global_pool = global_pool
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.patch_embed(x)
-        if self.absolute_pos_embed is not None:
-            x = x + self.absolute_pos_embed
-        x = self.pos_drop(x)
-
-        for layer in self.layers:
-            x = layer(x)
-
-        x = self.norm(x)  # B L C
+        if self.output_nchw:
+            # patch embed always outputs NHWC, stage layers expect NCHW input if output_nchw = True
+            x = x.permute(0, 3, 1, 2)
+        x = self.layers(x)
+        if self.output_nchw:
+            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        else:
+            x = self.norm(x)
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        if self.global_pool == 'avg':
-            x = x.mean(dim=1)
-        return x if pre_logits else self.head(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -620,25 +625,87 @@ class SwinTransformerV2(nn.Module):
 
 
 def checkpoint_filter_fn(state_dict, model):
+    import re
     out_dict = {}
-    if 'model' in state_dict:
-        # For deit models
-        state_dict = state_dict['model']
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('state_dict', state_dict)
     for k, v in state_dict.items():
         if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
             continue  # skip buffers that should not be persistent
+        k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
+        k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
     return out_dict
 
 
 def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+
     model = build_model_with_cfg(
         SwinTransformerV2, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs)
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'swinv2_tiny_window8_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
+    ),
+    'swinv2_tiny_window16_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
+    ),
+    'swinv2_small_window8_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
+    ),
+    'swinv2_small_window16_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
+    ),
+    'swinv2_base_window8_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
+    ),
+    'swinv2_base_window16_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
+    ),
+
+    'swinv2_base_window12_192_22k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
+    ),
+    'swinv2_base_window12to16_192to256_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
+    ),
+    'swinv2_base_window12to24_192to384_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
+    ),
+    'swinv2_large_window12_192_22k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
+    ),
+    'swinv2_large_window12to16_192to256_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
+    ),
+    'swinv2_large_window12to24_192to384_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
+    ),
+}
+
+
 @register_model
 def swinv2_tiny_window16_256(pretrained=False, **kwargs):
     """
diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index 9185e3e7..90b5c228 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -37,7 +37,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, to_2tuple, _assert
+from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply
@@ -48,60 +48,6 @@ __all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn
 _logger = logging.getLogger(__name__)
 
 
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000,
-        'input_size': (3, 224, 224),
-        'pool_size': (7, 7),
-        'crop_pct': 0.9,
-        'interpolation': 'bicubic',
-        'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN,
-        'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj',
-        'classifier': 'head',
-        **kwargs,
-    }
-
-
-default_cfgs = {
-    'swinv2_cr_tiny_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_tiny_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_tiny_ns_224': _cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
-        input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_small_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_small_224': _cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
-        input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_small_ns_224': _cfg(
-        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
-        input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_base_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_base_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_base_ns_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_large_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_large_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_huge_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_huge_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_giant_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_giant_224': _cfg(
-        url="", input_size=(3, 224, 224), crop_pct=0.9),
-}
-
-
 def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
     """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
     return x.permute(0, 2, 3, 1)
@@ -230,22 +176,14 @@ class WindowMultiHeadAttention(nn.Module):
         relative_position_bias = relative_position_bias.unsqueeze(0)
         return relative_position_bias
 
-    def _forward_sequential(
-            self,
-            x: torch.Tensor,
-            mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """
-        """
-        # FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
-        assert False, "not implemented"
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """ Forward pass.
+        Args:
+            x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
+            mask (Optional[torch.Tensor]): Attention mask for the shift case
 
-    def _forward_batch(
-            self,
-            x: torch.Tensor,
-            mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """This function performs standard (non-sequential) scaled cosine self-attention.
+        Returns:
+            Output tensor of the shape [B * windows, N, C]
         """
         Bw, L, C = x.shape
@@ -272,22 +210,8 @@ class WindowMultiHeadAttention(nn.Module):
         x = self.proj_drop(x)
         return x
 
-    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """ Forward pass.
-        Args:
-            x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
-            mask (Optional[torch.Tensor]): Attention mask for the shift case
-        Returns:
-            Output tensor of the shape [B * windows, N, C]
-        """
-        if self.sequential_attn:
-            return self._forward_sequential(x, mask)
-        else:
-            return self._forward_batch(x, mask)
-
-
-class SwinTransformerBlock(nn.Module):
+class SwinTransformerV2CrBlock(nn.Module):
     r"""This class implements the Swin transformer block.
 
     Args:
@@ -321,7 +245,7 @@ class SwinTransformerBlock(nn.Module):
             sequential_attn: bool = False,
             norm_layer: Type[nn.Module] = nn.LayerNorm,
     ) -> None:
-        super(SwinTransformerBlock, self).__init__()
+        super(SwinTransformerV2CrBlock, self).__init__()
         self.dim: int = dim
         self.feat_size: Tuple[int, int] = feat_size
         self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
@@ -410,9 +334,7 @@ class SwinTransformerBlock(nn.Module):
         self._make_attention_mask()
 
     def _shifted_window_attn(self, x):
-        H, W = self.feat_size
-        B, L, C = x.shape
-        x = x.view(B, H, W, C)
+        B, H, W, C = x.shape
 
         # cyclic shift
         sh, sw = self.shift_size
@@ -441,7 +363,6 @@ class SwinTransformerBlock(nn.Module):
             # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
             x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
 
-        x = x.view(B, L, C)
         return x
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -455,8 +376,12 @@ class SwinTransformerBlock(nn.Module):
         """
         # post-norm branches (op -> norm -> drop)
         x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
+
+        B, H, W, C = x.shape
+        x = x.reshape(B, -1, C)
         x = x + self.drop_path2(self.norm2(self.mlp(x)))
         x = self.norm3(x)  # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
+        x = x.reshape(B, H, W, C)
         return x
 
 
@@ -479,12 +404,10 @@ class PatchMerging(nn.Module):
         Returns:
             output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
         """
-        B, C, H, W = x.shape
-        # unfold + BCHW -> BHWC together
-        # ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
-        x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
+        B, H, W, C = x.shape
+        x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
         x = self.norm(x)
-        x = bhwc_to_bchw(self.reduction(x))
+        x = self.reduction(x)
         return x
 
 
@@ -511,7 +434,7 @@ class PatchEmbed(nn.Module):
         return x
 
 
-class SwinTransformerStage(nn.Module):
+class SwinTransformerV2CrStage(nn.Module):
     r"""This class implements a stage of the Swin transformer including multiple layers.
 
     Args:
@@ -549,12 +472,16 @@ class SwinTransformerStage(nn.Module):
             extra_norm_stage: bool = False,
             sequential_attn: bool = False,
     ) -> None:
-        super(SwinTransformerStage, self).__init__()
+        super(SwinTransformerV2CrStage, self).__init__()
         self.downscale: bool = downscale
         self.grad_checkpointing: bool = False
         self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
 
-        self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
+        if downscale:
+            self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
+            embed_dim = embed_dim * 2
+        else:
+            self.downsample = nn.Identity()
 
         def _extra_norm(index):
             i = index + 1
@@ -562,9 +489,8 @@ class SwinTransformerStage(nn.Module):
                 return True
             return i == depth if extra_norm_stage else False
 
-        embed_dim = embed_dim * 2 if downscale else embed_dim
         self.blocks = nn.Sequential(*[
-            SwinTransformerBlock(
+            SwinTransformerV2CrBlock(
                 dim=embed_dim,
                 num_heads=num_heads,
                 feat_size=self.feat_size,
@@ -602,18 +528,15 @@ class SwinTransformerStage(nn.Module):
         Returns:
             output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
         """
+        x = bchw_to_bhwc(x)
         x = self.downsample(x)
-        B, C, H, W = x.shape
-        L = H * W
-
-        x = bchw_to_bhwc(x).reshape(B, L, C)
         for block in self.blocks:
             # Perform checkpointing if utilized
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint.checkpoint(block, x)
             else:
                 x = block(x)
-        x = bhwc_to_bchw(x.reshape(B, H, W, -1))
+        x = bhwc_to_bchw(x)
         return x
 
 
@@ -676,39 +599,54 @@ class SwinTransformerV2Cr(nn.Module):
         self.img_size: Tuple[int, int] = img_size
         self.window_size: int = window_size
         self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
+        self.feature_info = []
 
         self.patch_embed = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
-            embed_dim=embed_dim, norm_layer=norm_layer)
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer,
+        )
         patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
-        drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
 
         stages = []
-        for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):
-            stage_scale = 2 ** max(index - 1, 0)
-            stages.append(
-                SwinTransformerStage(
-                    embed_dim=embed_dim * stage_scale,
-                    depth=depth,
-                    downscale=index != 0,
-                    feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
-                    num_heads=num_heads,
-                    window_size=window_size,
-                    mlp_ratio=mlp_ratio,
-                    init_values=init_values,
-                    drop=drop_rate,
-                    drop_attn=attn_drop_rate,
-                    drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
-                    extra_norm_period=extra_norm_period,
-                    extra_norm_stage=extra_norm_stage or (index + 1) == len(depths),  # last stage ends w/ norm
-                    sequential_attn=sequential_attn,
-                    norm_layer=norm_layer,
-                )
-            )
+        in_dim = embed_dim
+        in_scale = 1
+        for stage_idx, (depth, num_heads) in enumerate(zip(depths, num_heads)):
+            stages += [SwinTransformerV2CrStage(
+                embed_dim=in_dim,
+                depth=depth,
+                downscale=stage_idx != 0,
+                feat_size=(
+                    patch_grid_size[0] // in_scale,
+                    patch_grid_size[1] // in_scale
+                ),
+                num_heads=num_heads,
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                init_values=init_values,
+                drop=drop_rate,
+                drop_attn=attn_drop_rate,
+                drop_path=dpr[stage_idx],
+                extra_norm_period=extra_norm_period,
+                extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths),  # last stage ends w/ norm
+                sequential_attn=sequential_attn,
+                norm_layer=norm_layer,
+            )]
+            if stage_idx != 0:
+                in_dim *= 2
+                in_scale *= 2
+            self.feature_info += [dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}')]
         self.stages = nn.Sequential(*stages)
 
-        self.global_pool: str = global_pool
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
+        self.head = ClassifierHead(
+            self.num_features,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+        )
 
         # current weight init skips custom init and uses pytorch layer defaults, seems to work well
         # FIXME more experiments needed
@@ -765,7 +703,7 @@ class SwinTransformerV2Cr(nn.Module):
         Returns:
             head (nn.Module): Current classification head
         """
-        return self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
         """Method results the classification head
@@ -774,10 +712,8 @@ class SwinTransformerV2Cr(nn.Module):
             num_classes (int): Number of classes to be predicted
             global_pool (str): Unused
         """
-        self.num_classes: int = num_classes
-        if global_pool is not None:
-            self.global_pool = global_pool
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
@@ -785,9 +721,7 @@ class SwinTransformerV2Cr(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        if self.global_pool == 'avg':
-            x = x.mean(dim=(2, 3))
-        return x if pre_logits else self.head(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)
@@ -815,29 +749,87 @@ def init_weights(module: nn.Module, name: str = ''):
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     out_dict = {}
-    if 'model' in state_dict:
-        # For deit models
-        state_dict = state_dict['model']
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('state_dict', state_dict)
     for k, v in state_dict.items():
         if 'tau' in k:
             # convert old tau based checkpoints -> logit_scale (inverse)
             v = torch.log(1 / v)
             k = k.replace('tau', 'logit_scale')
+        k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
     return out_dict
 
 
 def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+
     model = build_model_with_cfg(
         SwinTransformerV2Cr, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs
     )
     return model
 
 
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000,
+        'input_size': (3, 224, 224),
+        'pool_size': (7, 7),
+        'crop_pct': 0.9,
+        'interpolation': 'bicubic',
+        'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN,
+        'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.proj',
+        'classifier': 'head.fc',
+        **kwargs,
+    }
+
+
+default_cfgs = {
+    'swinv2_cr_tiny_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
+    'swinv2_cr_tiny_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_tiny_ns_224': _cfg(
+        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
+        input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_small_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
+    'swinv2_cr_small_224': _cfg(
+        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
+        input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_small_ns_224': _cfg(
+        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
+        input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_small_ns_256': _cfg(
+        url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
+    'swinv2_cr_base_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
+    'swinv2_cr_base_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_base_ns_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_large_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
+    'swinv2_cr_large_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_huge_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
+    'swinv2_cr_huge_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+    'swinv2_cr_giant_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
+    'swinv2_cr_giant_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=0.9),
+}
+
+
 @register_model
 def swinv2_cr_tiny_384(pretrained=False, **kwargs):
     """Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
@@ -915,6 +907,19 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
     return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
 
 
+@register_model
+def swinv2_cr_small_ns_256(pretrained=False, **kwargs):
+    """Swin-S V2 CR @ 256x256, trained ImageNet-1k"""
+    model_kwargs = dict(
+        embed_dim=96,
+        depths=(2, 2, 18, 2),
+        num_heads=(3, 6, 12, 24),
+        extra_norm_stage=True,
+        **kwargs
+    )
+    return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_256', pretrained=pretrained, **model_kwargs)
+
+
 @register_model
 def swinv2_cr_base_384(pretrained=False, **kwargs):
     """Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
@@ -961,8 +966,7 @@ def swinv2_cr_large_384(pretrained=False, **kwargs):
         num_heads=(6, 12, 24, 48),
         **kwargs
     )
-    return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs
-    )
+    return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs)
 
 
 @register_model
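Usage sketch (illustrative, not part of the patch): the builder changes above give both SwinV2 variants a working features_only path via feature_cfg(flatten_sequential=True, out_indices=...), and they route the classifier through head.fc with ClassifierHead.reset() behind reset_classifier(). The snippet assumes this diff is applied; the model names come from the registrations in it, everything else is standard timm API.

import torch
import timm

# features_only no longer raises for swinv2_cr_* models; out_indices picks the stages to return.
model = timm.create_model('swinv2_cr_small_ns_256', features_only=True, out_indices=(0, 1, 2, 3))
feats = model(torch.randn(1, 3, 256, 256))
print(model.feature_info.channels(), model.feature_info.reduction())
for f in feats:
    print(tuple(f.shape))  # one NCHW feature map per requested stage

# The classifier now lives under head.fc (matching 'classifier': 'head.fc' in the new _cfg),
# and reset_classifier() delegates to ClassifierHead.reset().
clf = timm.create_model('swinv2_cr_tiny_224', num_classes=10)
print(clf.get_classifier())  # nn.Linear exposed via head.fc
clf.reset_classifier(0, '')  # remove pooling + classifier, head becomes a pass-through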