Merge pull request #1628 from huggingface/focalnet_and_swin_refactor
Add FocalNet arch, refactor Swin V1/V2 for better feature extraction and HF hub multi-weight support

commit 0d5c5c39fc
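As a rough illustration of what this refactor enables (the model name below is only an assumption for illustration, not taken from this diff), the reworked Swin/FocalNet architectures can now be used through the standard features_only path:

import torch
import timm

# hypothetical usage sketch; exact model/tag names depend on what this PR registers
model = timm.create_model('swinv2_tiny_window8_256', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 256, 256))
print([f.shape for f in feats])
print(model.feature_info.channels(), model.feature_info.reduction())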
@@ -27,7 +27,8 @@ except ImportError:
import timm
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models._features_fx import _leaf_modules, _autowrap_functions
from timm.layers import Format, get_spatial_dim, get_channel_dim
from timm.models import get_notrace_modules, get_notrace_functions

if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests

@@ -37,10 +38,10 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*'
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*',
]
NUM_NON_STD = len(NON_STD_FILTERS)

@@ -52,7 +53,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
NON_STD_EXCLUDE_FILTERS = ['vit_gi*']
@@ -145,7 +146,8 @@ def test_model_backward(model_name, batch_size):
@pytest.mark.cfg
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
"""Run a single forward pass with each model"""

@@ -156,6 +158,10 @@ def test_model_default_cfgs(model_name, batch_size):
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
output_fmt = getattr(model, 'output_fmt', 'NCHW')
spatial_axis = get_spatial_dim(output_fmt)
assert len(spatial_axis) == 2 # TODO add 1D sequence support
feat_axis = get_channel_dim(output_fmt)

if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):

@@ -165,13 +171,16 @@ def test_model_default_cfgs(model_name, batch_size):
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
assert outputs.shape[feat_axis] == model.num_features

# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
assert outputs.shape[-1] == model.num_features
assert outputs.shape[1] == model.num_features

# test model forward without pooling and classifier
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through

@@ -179,7 +188,7 @@ def test_model_default_cfgs(model_name, batch_size):
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__

@@ -187,7 +196,7 @@ def test_model_default_cfgs(model_name, batch_size):
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

# check classifier name matches default_cfg
if cfg.get('num_classes', None):
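The spatial_axis / feat_axis indirection introduced above is what lets the same asserts cover NCHW and NHWC outputs. A minimal sketch of the idea, using the Format helpers added in this PR (imports assumed to be exported from timm.layers as the diff indicates):

import torch
from timm.layers import get_spatial_dim, get_channel_dim

nchw = torch.randn(2, 96, 7, 7)
nhwc = nchw.permute(0, 2, 3, 1)
for fmt, out in (('NCHW', nchw), ('NHWC', nhwc)):
    spatial_axis = get_spatial_dim(fmt)   # (2, 3) for NCHW, (1, 2) for NHWC
    feat_axis = get_channel_dim(fmt)      # 1 for NCHW, 3 for NHWC
    assert out.shape[spatial_axis[0]] == 7 and out.shape[spatial_axis[1]] == 7
    assert out.shape[feat_axis] == 96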
@@ -330,176 +339,181 @@ def test_model_forward_features(model_name, batch_size):
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6

input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math

outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
for e, o in zip(expected_channels, outputs):
assert e == o.shape[1]
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, outputs):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()


if not _IS_MAC:
# MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer_kwargs = dict(
leaf_modules=get_notrace_modules(),
autowrap_functions=get_notrace_functions(),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)

def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer_kwargs = dict(
leaf_modules=list(_leaf_modules),
autowrap_functions=list(_autowrap_functions),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]

eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]

fx_model = create_feature_extractor(
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model
fx_model = create_feature_extractor(
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]


@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")

model = create_model(model_name, pretrained=False)
model.eval()

input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)

model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)

assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'


@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")

input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")

model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")

model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])

assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process

# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS


@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")

model = create_model(model_name, pretrained=False)
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")

with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()

input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)

model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)

assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'


@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")

input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")

model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")

model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])

assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'


if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process

# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS


@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")

input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")

with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()

model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)

assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
root,
reader=None,
split='train',
class_map=None,
is_training=False,
batch_size=None,
seed=42,

@@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
reader,
root=root,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
seed=seed,
@@ -157,6 +157,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
download=download,
batch_size=batch_size,

@@ -169,6 +170,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,
@@ -7,12 +7,14 @@ from typing import Dict, List, Optional, Union
from .dataset_info import DatasetInfo


# NOTE no ambiguity wrt to mapping from # classes to ImageNet subset so far, but likely to change
_NUM_CLASSES_TO_SUBSET = {
1000: 'imagenet-1k',
11821: 'imagenet-12k',
21841: 'imagenet-22k',
21843: 'imagenet-21k-goog',
11221: 'imagenet-21k-miil',
11221: 'imagenet-21k-miil', # miil subset of fall11
11821: 'imagenet-12k', # timm specific 12k subset of fall11
21841: 'imagenet-22k', # as in fall11.tar
21842: 'imagenet-22k-ms', # a Microsoft (for FocalNet) remapping of 22k w/ moves ImageNet-1k classes to first 1000
21843: 'imagenet-21k-goog', # Google's ImageNet full has two classes not in fall11
}

_SUBSETS = {

@@ -22,6 +24,7 @@ _SUBSETS = {
'imagenet21k': 'imagenet21k_goog_synsets.txt',
'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
'imagenet22kms': 'imagenet22k_ms_synsets.txt',
}
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
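The class-count-to-subset mapping above is what lets label metadata be inferred from a pretrained head size. A minimal lookup sketch (the helper name is made up for illustration):

_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11221: 'imagenet-21k-miil',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21842: 'imagenet-22k-ms',
    21843: 'imagenet-21k-goog',
}

def guess_label_subset(num_classes: int) -> str:
    # hypothetical helper: map a classifier head size back to the ImageNet subset it was trained on
    return _NUM_CLASSES_TO_SUBSET.get(num_classes, 'unknown')

assert guess_label_subset(21842) == 'imagenet-22k-ms'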
@@ -34,6 +34,7 @@ except ImportError as e:
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)

from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount

@@ -94,6 +95,7 @@ class ReaderTfds(Reader):
root,
name,
split='train',
class_map=None,
is_training=False,
batch_size=None,
download=False,

@@ -151,7 +153,12 @@ class ReaderTfds(Reader):
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
if download:
self.builder.download_and_prepare()
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples

@@ -299,6 +306,8 @@ class ReaderTfds(Reader):
target_data = sample[self.target_name]
if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
@@ -29,6 +29,7 @@ except ImportError:
wds = None
expand_urls = None

from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount

@@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
info_yaml = os.path.join(root, basename + '.yaml')
err_str = ''
try:
with wds.gopen.gopen(info_json) as f:
with wds.gopen(info_json) as f:
info_dict = json.load(f)
return info_dict
except Exception as e:
err_str = str(e)
try:
with wds.gopen.gopen(info_yaml) as f:
with wds.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception:

@@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
if 'splits' not in info or split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)

@@ -290,6 +291,7 @@ class ReaderWds(Reader):
batch_size=None,
repeats=0,
seed=42,
class_map=None,
input_name='jpg',
input_image='RGB',
target_name='cls',

@@ -320,6 +322,12 @@ class ReaderWds(Reader):
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = {}

# Distributed world state
self.dist_rank = 0

@@ -431,7 +439,10 @@ class ReaderWds(Reader):
i = 0
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds:
yield sample[self.image_key], sample[self.target_key]
target = sample[self.target_key]
if self.remap_class:
target = self.class_to_idx[target]
yield sample[self.image_key], target
i += 1
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
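Both readers now share the same optional remapping: when a class_map is supplied, targets coming out of the dataset are translated through class_to_idx before being yielded. A standalone sketch of that pattern (sample layout here is assumed, not taken from the diff):

def remap_targets(samples, class_to_idx=None):
    # samples: iterable of (image, target) pairs, as yielded by the readers above
    remap_class = bool(class_to_idx)
    for img, target in samples:
        if remap_class:
            target = class_to_idx[target]  # e.g. synset string or original id -> contiguous index
        yield img, target

pairs = [('img0', 'n01440764'), ('img1', 'n01443537')]
print(list(remap_targets(pairs, {'n01440764': 0, 'n01443537': 1})))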
@@ -20,6 +20,7 @@ from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
@@ -9,31 +9,37 @@ Both a functional and a nn.Module version of the pooling is provided.
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .format import get_spatial_dim, get_channel_dim

_int_tuple_2_t = Union[int, Tuple[int, int]]


def adaptive_pool_feat_mult(pool_type='avg'):
if pool_type == 'catavgmax':
if pool_type.endswith('catavgmax'):
return 2
else:
return 1


def adaptive_avgmax_pool2d(x, output_size=1):
def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool2d(x, output_size=1):
def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
"""Selectable global pooling function with dynamic input kernel size
"""
if pool_type == 'avg':
@@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
return x


class FastAdaptiveAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastAdaptiveAvgPool2d, self).__init__()
class FastAdaptiveAvgPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: F = 'NCHW'):
super(FastAdaptiveAvgPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)

def forward(self, x):
return x.mean((2, 3), keepdim=not self.flatten)
return x.mean(self.dim, keepdim=not self.flatten)


class FastAdaptiveMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)

def forward(self, x):
return x.amax(self.dim, keepdim=not self.flatten)


class FastAdaptiveAvgMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)

def forward(self, x):
x_avg = x.mean(self.dim, keepdim=not self.flatten)
x_max = x.amax(self.dim, keepdim=not self.flatten)
return 0.5 * x_avg + 0.5 * x_max


class FastAdaptiveCatAvgMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveCatAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim_reduce = get_spatial_dim(input_fmt)
if flatten:
self.dim_cat = 1
else:
self.dim_cat = get_channel_dim(input_fmt)

def forward(self, x):
x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten)
x_max = x.amax(self.dim_reduce, keepdim=not self.flatten)
return torch.cat((x_avg, x_max), self.dim_cat)


class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size

@@ -68,7 +113,7 @@ class AdaptiveAvgMaxPool2d(nn.Module):

class AdaptiveCatAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveCatAvgMaxPool2d, self).__init__()
self.output_size = output_size

@@ -79,26 +124,41 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size
"""
def __init__(self, output_size=1, pool_type='fast', flatten=False):
def __init__(
self,
output_size: _int_tuple_2_t = 1,
pool_type: str = 'fast',
flatten: bool = False,
input_fmt: str = 'NCHW',
):
super(SelectAdaptivePool2d, self).__init__()
assert input_fmt in ('NCHW', 'NHWC')
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
if pool_type == '':
if not pool_type:
self.pool = nn.Identity() # pass through
elif pool_type == 'fast':
assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(flatten)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
elif pool_type.startswith('fast') or input_fmt != 'NCHW':
assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
if pool_type.endswith('avgmax'):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('catavgmax'):
self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else:
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
self.flatten = nn.Identity()
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
assert input_fmt == 'NCHW'
if pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
self.pool = nn.AdaptiveAvgPool2d(output_size)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()

def is_identity(self):
return not self.pool_type
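With the format-aware fast pools above, global pooling can run directly on channels-last feature maps. A small usage sketch, assuming SelectAdaptivePool2d is importable from timm.layers as in released timm (the expected shape follows from the pool classes in this diff):

import torch
from timm.layers import SelectAdaptivePool2d

pool = SelectAdaptivePool2d(pool_type='fastavg', flatten=True, input_fmt='NHWC')
x = torch.randn(2, 7, 7, 192)   # NHWC feature map, e.g. from a Swin stage
out = pool(x)
print(out.shape)                # expected: torch.Size([2, 192]) - mean over the (H, W) axes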
@@ -15,13 +15,23 @@ from .create_act import get_act_layer
from .create_norm import get_norm_layer


def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
def _create_pool(
num_features: int,
num_classes: int,
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: Optional[str] = None,
):
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
global_pool = SelectAdaptivePool2d(
pool_type=pool_type,
flatten=flatten_in_pool,
input_fmt=input_fmt,
)
num_pooled_features = num_features * global_pool.feat_mult()
return global_pool, num_pooled_features

@@ -36,9 +46,25 @@ def _create_fc(num_features, num_classes, use_conv=False):
return fc


def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
def create_classifier(
num_features: int,
num_classes: int,
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
global_pool, num_pooled_features = _create_pool(
num_features,
num_classes,
pool_type,
use_conv=use_conv,
input_fmt=input_fmt,
)
fc = _create_fc(
num_pooled_features,
num_classes,
use_conv=use_conv,
)
return global_pool, fc
@@ -52,6 +78,7 @@ class ClassifierHead(nn.Module):
pool_type: str = 'avg',
drop_rate: float = 0.,
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
"""
Args:

@@ -64,28 +91,43 @@ class ClassifierHead(nn.Module):
self.drop_rate = drop_rate
self.in_features = in_features
self.use_conv = use_conv
self.input_fmt = input_fmt

self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.global_pool, self.fc = create_classifier(
in_features,
num_classes,
pool_type,
use_conv=use_conv,
input_fmt=input_fmt,
)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

def reset(self, num_classes, global_pool=None):
if global_pool is not None:
if global_pool != self.global_pool.pool_type:
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
def reset(self, num_classes, pool_type=None):
if pool_type is not None and pool_type != self.global_pool.pool_type:
self.global_pool, self.fc = create_classifier(
self.in_features,
num_classes,
pool_type=pool_type,
use_conv=self.use_conv,
input_fmt=self.input_fmt,
)
self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
else:
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(
num_pooled_features,
num_classes,
use_conv=self.use_conv,
)

def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
x = self.fc(x)
return self.flatten(x)


class NormMlpClassifierHead(nn.Module):
@@ -0,0 +1,58 @@
from enum import Enum
from typing import Union

import torch


class Format(str, Enum):
NCHW = 'NCHW'
NHWC = 'NHWC'
NCL = 'NCL'
NLC = 'NLC'


FormatT = Union[str, Format]


def get_spatial_dim(fmt: FormatT):
fmt = Format(fmt)
if fmt is Format.NLC:
dim = (1,)
elif fmt is Format.NCL:
dim = (2,)
elif fmt is Format.NHWC:
dim = (1, 2)
else:
dim = (2, 3)
return dim


def get_channel_dim(fmt: FormatT):
fmt = Format(fmt)
if fmt is Format.NHWC:
dim = 3
elif fmt is Format.NLC:
dim = 2
else:
dim = 1
return dim


def nchw_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(2).transpose(1, 2)
elif fmt == Format.NCL:
x = x.flatten(2)
return x


def nhwc_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NCHW:
x = x.permute(0, 3, 1, 2)
elif fmt == Format.NLC:
x = x.flatten(1, 2)
elif fmt == Format.NCL:
x = x.flatten(1, 2).transpose(1, 2)
return x
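The new format module is small enough to exercise directly; a quick sketch of the helpers defined above:

import torch
from timm.layers import Format, get_spatial_dim, get_channel_dim, nchw_to

x = torch.randn(2, 64, 14, 14)                       # NCHW
assert get_spatial_dim(Format.NHWC) == (1, 2)
assert get_channel_dim('NHWC') == 3                  # str or Format are both accepted
assert nchw_to(x, Format.NHWC).shape == (2, 14, 14, 64)
assert nchw_to(x, Format.NLC).shape == (2, 14 * 14, 64)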
@@ -9,12 +9,13 @@ Based on code in:
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List
from typing import List, Optional, Callable

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .trace_utils import _assert

@@ -24,15 +25,18 @@ _logger = logging.getLogger(__name__)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
output_fmt: Format

def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
):
super().__init__()
img_size = to_2tuple(img_size)

@@ -41,7 +45,13 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

@@ -52,7 +62,9 @@ class PatchEmbed(nn.Module):
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x
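With the output_fmt argument added above, the same embedding layer can hand back either a token sequence or a spatial map. A sketch, assuming PatchEmbed is importable from timm.layers:

import torch
from timm.layers import PatchEmbed

x = torch.randn(1, 3, 224, 224)
tokens = PatchEmbed(img_size=224, patch_size=16, embed_dim=96)(x)
spatial = PatchEmbed(img_size=224, patch_size=16, embed_dim=96, output_fmt='NHWC')(x)
print(tokens.shape)   # torch.Size([1, 196, 96])   - flattened NLC, the bwd-compat default
print(spatial.shape)  # torch.Size([1, 14, 14, 96]) - NHWC, keeps spatial dims for feature extraction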
@@ -15,26 +15,48 @@ from .weight_init import trunc_normal_

def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
k_size: Optional[Tuple[int, int]] = None,
class_token: bool = False,
) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
# FIXME different q vs k sizes is a WIP, need to better offset the two grids?
q_coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
k_coords = torch.stack(
torch.meshgrid([
torch.arange(k_size[0]),
torch.arange(k_size[1])
])
).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3

_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)

if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
@@ -59,7 +81,7 @@ class RelPosBias(nn.Module):
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1),
persistent=False,
)

@@ -69,7 +91,7 @@ class RelPosBias(nn.Module):
trunc_normal_(self.relative_position_bias_table, std=.02)

def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()

@@ -148,7 +170,7 @@ class RelPosMlp(nn.Module):

self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
gen_relative_position_index(window_size).view(-1),
persistent=False)

# get relative_coords_table

@@ -160,8 +182,7 @@ class RelPosMlp(nn.Module):
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
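The torch.unique(..., return_inverse=True) trick above replaces the older manual offset arithmetic for building the relative position index. The idea in isolation, on a tiny 2x2 window, using plain torch rather than the timm helper:

import torch

q_size = (2, 2)
coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, 4
rel = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0)  # 4, 4, 2 pairwise offsets
uniq, index = torch.unique(rel.view(-1, 2), return_inverse=True, dim=0)
print(uniq.shape[0])   # 9 == (2*2-1) * (2*2-1) distinct relative offsets
print(index.shape)     # torch.Size([16]) - flat lookup into the bias table, stored via .view(-1) above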
@@ -17,6 +17,7 @@ from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *
from .gluon_resnet import *

@@ -71,13 +72,14 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, register_notrace_function
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
register_notrace_module, is_notrace_module, get_notrace_modules, \
register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, \
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
from ._prune import adapt_model_from_string
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
@@ -392,6 +392,9 @@ def build_model_with_cfg(
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None:
feature_cfg.setdefault('output_fmt', output_fmt)
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):

@@ -403,7 +406,7 @@ def build_model_with_cfg(
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)

return model
@@ -3,10 +3,10 @@ from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit

from timm.layers import set_layer_config
from ._pretrained import PretrainedCfg, split_model_name_tag
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint
from ._pretrained import PretrainedCfg
from ._registry import is_model, model_entrypoint, split_model_name_tag


__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
@@ -17,6 +17,8 @@ import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from timm.layers import Format


__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']

@@ -181,6 +183,7 @@ class FeatureDictNet(nn.ModuleDict):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
):

@@ -195,6 +198,7 @@ class FeatureDictNet(nn.ModuleDict):
"""
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.output_fmt = Format(output_fmt)
self.concat = feature_concat
self.grad_checkpointing = False
self.return_layers = {}

@@ -253,6 +257,7 @@ class FeatureListNet(FeatureDictNet):
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
):

@@ -264,9 +269,10 @@ class FeatureListNet(FeatureDictNet):
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
"""
super(FeatureListNet, self).__init__(
super().__init__(
model,
out_indices=out_indices,
output_fmt=output_fmt,
feature_concat=feature_concat,
flatten_sequential=flatten_sequential,
)

@@ -293,7 +299,8 @@ class FeatureHookNet(nn.ModuleDict):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_as_dict: bool = False,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',

@@ -304,16 +311,17 @@ class FeatureHookNet(nn.ModuleDict):
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
out_as_dict: Output features as a dict.
return_dict: Output features as a dict.
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
flatten_sequential arg must also be False if this is set True.
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
default_hook_type: The default hook type to use if not specified in model.feature_info.
"""
super(FeatureHookNet, self).__init__()
super().__init__()
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
self.return_dict = return_dict
self.output_fmt = Format(output_fmt)
self.grad_checkpointing = False

layers = OrderedDict()

@@ -356,4 +364,4 @@ class FeatureHookNet(nn.ModuleDict):
else:
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())
return out if self.return_dict else list(out.values())
@@ -19,6 +19,11 @@ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSa
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']


# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {

@@ -35,10 +40,6 @@ except ImportError:
pass


__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
'FeatureGraphNet', 'GraphExtractNet']


def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.

@@ -47,6 +48,14 @@ def register_notrace_module(module: Type[nn.Module]):
return module


def is_notrace_module(module: Type[nn.Module]):
return module in _leaf_modules


def get_notrace_modules():
return list(_leaf_modules)


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()

@@ -59,6 +68,14 @@ def register_notrace_function(func: Callable):
return func


def is_notrace_function(func: Callable):
return func in _autowrap_functions


def get_notrace_functions():
return list(_autowrap_functions)


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(
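The new is_notrace_* / get_notrace_* accessors give tests and downstream code a public view of the FX leaf registries (this is what tests/test_models.py switches to earlier in this diff). A sketch of the intended use, assuming these names are exported from timm.models as the __init__ diff indicates:

import torch.nn as nn
from timm.models import register_notrace_module, is_notrace_module, get_notrace_modules

@register_notrace_module
class MyAttnBlock(nn.Module):   # hypothetical module that FX tracing should treat as a leaf
    def forward(self, x):
        return x

assert is_notrace_module(MyAttnBlock)
assert MyAttnBlock in get_notrace_modules()
tracer_kwargs = dict(leaf_modules=get_notrace_modules())   # as used when building FX feature extractors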
@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
try:
|
||||
|
@ -13,30 +14,32 @@ try:
|
|||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
|
||||
|
||||
|
||||
def clean_state_dict(state_dict):
|
||||
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# 'clean' checkpoint by removing the '.module.' prefix that (Distributed)DataParallel training adds to state dict keys, if present
|
||||
cleaned_state_dict = OrderedDict()
|
||||
cleaned_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module.') else k
|
||||
cleaned_state_dict[name] = v
|
||||
return cleaned_state_dict
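A tiny pure-python illustration of the clean-up above: the 'module.' prefix written by (Distributed)DataParallel is stripped, other keys pass through unchanged.

sd = {'module.conv1.weight': 0, 'fc.bias': 1}
cleaned = {k[7:] if k.startswith('module.') else k: v for k, v in sd.items()}
assert cleaned == {'conv1.weight': 0, 'fc.bias': 1}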
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_path, use_ema=True):
|
||||
def load_state_dict(
|
||||
checkpoint_path: str,
|
||||
use_ema: bool = True,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
) -> Dict[str, Any]:
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
# Check if safetensors or not and load weights accordingly
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
||||
state_dict_key = ''
|
||||
if isinstance(checkpoint, dict):
|
||||
|
@ -56,22 +59,37 @@ def load_state_dict(checkpoint_path, use_ema=True):
|
|||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
|
||||
def load_checkpoint(
|
||||
model: torch.nn.Module,
|
||||
checkpoint_path: str,
|
||||
use_ema: bool = True,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
strict: bool = True,
|
||||
remap: bool = False,
|
||||
filter_fn: Optional[Callable] = None,
|
||||
):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, 'load_pretrained'):
|
||||
timm.models._model_builder.load_pretrained(checkpoint_path)
|
||||
model.load_pretrained(checkpoint_path)
|
||||
else:
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
|
||||
if remap:
|
||||
state_dict = remap_checkpoint(model, state_dict)
|
||||
state_dict = remap_state_dict(state_dict, model)
|
||||
elif filter_fn:
|
||||
state_dict = filter_fn(state_dict, model)
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
return incompatible_keys
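A hedged usage sketch of the updated load_checkpoint signature; the checkpoint path is hypothetical and the call is skipped when the file is absent.

import os
import timm
from timm.models import load_checkpoint

model = timm.create_model('resnet18', pretrained=False)
ckpt = 'checkpoint.pth.tar'   # hypothetical local path
if os.path.exists(ckpt):
    # EMA weights loaded onto CPU; remap=True would remap keys by order,
    # filter_fn=lambda sd, m: ... allows custom state dict fix-ups before loading
    load_checkpoint(model, ckpt, use_ema=True, device='cpu', strict=True)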
|
||||
|
||||
|
||||
def remap_checkpoint(model, state_dict, allow_reshape=True):
|
||||
def remap_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
model: torch.nn.Module,
|
||||
allow_reshape: bool = True
|
||||
):
|
||||
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
|
||||
This assumes models (and originating state dict) were created with params registered in same order.
|
||||
"""
|
||||
|
@ -87,7 +105,13 @@ def remap_checkpoint(model, state_dict, allow_reshape=True):
|
|||
return out_dict
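A rough pure-python sketch (assumed from the docstring above) of what order-based remapping does: values are matched to the target model's parameters and buffers purely by registration order, not by key name.

import torch.nn as nn

dst = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
# pretend the originating checkpoint used different key names but the exact same layout
src_sd = {f'renamed_{i}': v.clone() for i, v in enumerate(dst.state_dict().values())}
remapped = {dk: sv for dk, sv in zip(dst.state_dict().keys(), src_sd.values())}
dst.load_state_dict(remapped)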
|
||||
|
||||
|
||||
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
||||
def resume_checkpoint(
|
||||
model: torch.nn.Module,
|
||||
checkpoint_path: str,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
loss_scaler: Any = None,
|
||||
log_info: bool = True,
|
||||
):
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
import re
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Callable, Union, Dict
|
||||
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
@ -13,7 +13,7 @@ __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_wi
|
|||
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
|
||||
|
||||
|
||||
def model_parameters(model, exclude_head=False):
|
||||
def model_parameters(model: nn.Module, exclude_head: bool = False):
|
||||
if exclude_head:
|
||||
# FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
|
||||
return [p for p in model.parameters()][:-2]
|
||||
|
@ -21,7 +21,12 @@ def model_parameters(model, exclude_head=False):
|
|||
return model.parameters()
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
||||
def named_apply(
|
||||
fn: Callable,
|
||||
module: nn.Module, name='',
|
||||
depth_first: bool = True,
|
||||
include_root: bool = False,
|
||||
) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
|
@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
|
|||
return module
|
||||
|
||||
|
||||
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
def named_modules(
|
||||
module: nn.Module,
|
||||
name: str = '',
|
||||
depth_first: bool = True,
|
||||
include_root: bool = False,
|
||||
):
|
||||
if not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
|
@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
|
|||
yield name, module
|
||||
|
||||
|
||||
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
def named_modules_with_params(
|
||||
module: nn.Module,
|
||||
name: str = '',
|
||||
depth_first: bool = True,
|
||||
include_root: bool = False,
|
||||
):
|
||||
if module._parameters and not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
|
@ -58,9 +73,9 @@ MATCH_PREV_GROUP = (99999,)
|
|||
|
||||
|
||||
def group_with_matcher(
|
||||
named_objects,
|
||||
named_objects: Iterator[Tuple[str, Any]],
|
||||
group_matcher: Union[Dict, Callable],
|
||||
output_values: bool = False,
|
||||
return_values: bool = False,
|
||||
reverse: bool = False
|
||||
):
|
||||
if isinstance(group_matcher, dict):
|
||||
|
@ -96,7 +111,7 @@ def group_with_matcher(
|
|||
# map layers into groups via ordinals (ints or tuples of ints) from matcher
|
||||
grouping = defaultdict(list)
|
||||
for k, v in named_objects:
|
||||
grouping[_get_grouping(k)].append(v if output_values else k)
|
||||
grouping[_get_grouping(k)].append(v if return_values else k)
|
||||
|
||||
# remap to integers
|
||||
layer_id_to_param = defaultdict(list)
|
||||
|
@ -107,7 +122,7 @@ def group_with_matcher(
|
|||
layer_id_to_param[lid].extend(grouping[k])
|
||||
|
||||
if reverse:
|
||||
assert not output_values, "reverse mapping only sensible for name output"
|
||||
assert not return_values, "reverse mapping only sensible for name output"
|
||||
# output reverse mapping
|
||||
param_to_layer_id = {}
|
||||
for lid, lm in layer_id_to_param.items():
|
||||
|
@ -121,24 +136,29 @@ def group_with_matcher(
|
|||
def group_parameters(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
return_values: bool = False,
|
||||
reverse: bool = False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
|
||||
module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)
|
||||
|
||||
|
||||
def group_modules(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
return_values: bool = False,
|
||||
reverse: bool = False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
|
||||
named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)
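A hedged sketch of the matcher-based grouping (as used for layer-wise LR decay); group_matcher() is assumed to be available on the chosen model and the import path mirrors the module in this hunk.

import timm
from timm.models._manipulate import group_parameters

model = timm.create_model('resnet18', pretrained=False)
matcher = model.group_matcher(coarse=True)     # dict / callable mapping param names to group ordinals
groups = group_parameters(model, matcher)      # {group_idx: [parameter names]}
print(len(groups), {k: len(v) for k, v in sorted(groups.items())})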
|
||||
|
||||
|
||||
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
|
||||
def flatten_modules(
|
||||
named_modules: Iterator[Tuple[str, nn.Module]],
|
||||
depth: int = 1,
|
||||
prefix: Union[str, Tuple[str, ...]] = '',
|
||||
module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
|
||||
):
|
||||
prefix_is_tuple = isinstance(prefix, tuple)
|
||||
if isinstance(module_types, str):
|
||||
if module_types == 'container':
|
||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass, field, replace, asdict
|
|||
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
||||
|
||||
|
||||
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
|
||||
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -91,41 +91,3 @@ class DefaultCfg:
|
|||
def default_with_tag(self):
|
||||
tag = self.tags[0]
|
||||
return tag, self.cfgs[tag]
|
||||
|
||||
|
||||
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||
model_name, *tag_list = model_name.split('.', 1)
|
||||
tag = tag_list[0] if tag_list else no_tag
|
||||
return model_name, tag
|
||||
|
||||
|
||||
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
|
||||
out = defaultdict(DefaultCfg)
|
||||
default_set = set() # no tag and tags ending with * are prioritized as default
|
||||
|
||||
for k, v in cfgs.items():
|
||||
if isinstance(v, dict):
|
||||
v = PretrainedCfg(**v)
|
||||
has_weights = v.has_weights
|
||||
|
||||
model, tag = split_model_name_tag(k)
|
||||
is_default_set = model in default_set
|
||||
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
|
||||
tag = tag.strip('*')
|
||||
|
||||
default_cfg = out[model]
|
||||
|
||||
if priority:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
default_set.add(model)
|
||||
elif has_weights and not default_cfg.is_pretrained:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
else:
|
||||
default_cfg.tags.append(tag)
|
||||
|
||||
if has_weights:
|
||||
default_cfg.is_pretrained = True
|
||||
|
||||
default_cfg.cfgs[tag] = v
|
||||
|
||||
return out
|
||||
|
|
|
@ -5,16 +5,19 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
import fnmatch
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg
|
||||
|
||||
__all__ = [
|
||||
'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained'
|
||||
]
|
||||
|
||||
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||
|
@ -23,12 +26,52 @@ _model_has_pretrained: Set[str] = set() # set of model names that have pretrain
|
|||
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
|
||||
_deprecated_models: Dict[str, Optional[str]] = {}
|
||||
|
||||
|
||||
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||
model_name, *tag_list = model_name.split('.', 1)
|
||||
tag = tag_list[0] if tag_list else no_tag
|
||||
return model_name, tag
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> str:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
|
||||
out = defaultdict(DefaultCfg)
|
||||
default_set = set() # no tag and tags ending with * are prioritized as default
|
||||
|
||||
for k, v in cfgs.items():
|
||||
if isinstance(v, dict):
|
||||
v = PretrainedCfg(**v)
|
||||
has_weights = v.has_weights
|
||||
|
||||
model, tag = split_model_name_tag(k)
|
||||
is_default_set = model in default_set
|
||||
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
|
||||
tag = tag.strip('*')
|
||||
|
||||
default_cfg = out[model]
|
||||
|
||||
if priority:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
default_set.add(model)
|
||||
elif has_weights and not default_cfg.is_pretrained:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
else:
|
||||
default_cfg.tags.append(tag)
|
||||
|
||||
if has_weights:
|
||||
default_cfg.is_pretrained = True
|
||||
|
||||
default_cfg.cfgs[tag] = v
|
||||
|
||||
return out
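A hedged sketch of the 'arch.tag' convention handled above; 'mymodel' is a hypothetical architecture and the cfg dicts are minimal stand-ins.

from timm.models._registry import split_model_name_tag, generate_default_cfgs

assert split_model_name_tag('convnext_tiny.fb_in1k') == ('convnext_tiny', 'fb_in1k')
assert split_model_name_tag('convnext_tiny') == ('convnext_tiny', '')

cfgs = generate_default_cfgs({
    'mymodel.untrained': {},                    # no weights
    'mymodel.in1k': {'hf_hub_id': 'timm/'},     # has weights, so it should be promoted to default
})
print(list(cfgs['mymodel'].cfgs.keys()), cfgs['mymodel'].tags[0])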
|
||||
|
||||
|
||||
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
# lookup containing module
|
||||
mod = sys.modules[fn.__module__]
|
||||
|
@ -87,6 +130,37 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
|||
return fn
|
||||
|
||||
|
||||
def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
|
||||
def _fn(pretrained=False, **kwargs):
|
||||
assert current_fn is not None, f'Model {deprecated_name} has been removed with no replacement.'
|
||||
warnings.warn(f'Mapping deprecated model {deprecated_name} to current {current_fn.__name__}', stacklevel=2)
|
||||
pretrained_cfg = kwargs.pop('pretrained_cfg', None)
|
||||
return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
|
||||
return _fn
|
||||
|
||||
|
||||
def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
|
||||
mod = sys.modules[module_name]
|
||||
module_name_split = module_name.split('.')
|
||||
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||
|
||||
for deprecated, current in deprecation_map.items():
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(deprecated)
|
||||
current_fn = None
|
||||
current_tag = ''
|
||||
if current:
|
||||
current_name, current_tag = split_model_name_tag(current)
|
||||
current_fn = getattr(mod, current_name)
|
||||
deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
|
||||
setattr(mod, deprecated, deprecated_entrypoint_fn)
|
||||
_model_entrypoints[deprecated] = deprecated_entrypoint_fn
|
||||
_model_to_module[deprecated] = module_name
|
||||
_module_to_models[module_name].add(deprecated)
|
||||
_deprecated_models[deprecated] = current
|
||||
_module_to_deprecated_models[module_name][deprecated] = current
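A hedged sketch of what the deprecation shim gives end users: the old flat model name still resolves via create_model, emits a mapping warning, and builds the model under its new 'arch.tag' pretrained cfg (example name taken from the convnext deprecation map later in this diff).

import warnings
import timm

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter('always')
    model = timm.create_model('convnext_tiny_in22ft1k', pretrained=False)  # deprecated alias
print(type(model).__name__, f'{len(w)} warning(s) emitted')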
|
||||
|
||||
|
||||
def _natural_key(string_: str) -> List[Union[int, str]]:
|
||||
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
@ -122,16 +196,14 @@ def list_models(
|
|||
# FIXME should this be default behaviour? or default to include_tags=True?
|
||||
include_tags = pretrained
|
||||
|
||||
if module:
|
||||
all_models: Iterable[str] = list(_module_to_models[module])
|
||||
else:
|
||||
all_models = _model_entrypoints.keys()
|
||||
all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
|
||||
all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings
|
||||
|
||||
if include_tags:
|
||||
# expand model names to include names w/ pretrained tags
|
||||
models_with_tags = []
|
||||
models_with_tags: Set[str] = set()
|
||||
for m in all_models:
|
||||
models_with_tags.extend(_model_with_tags[m])
|
||||
models_with_tags.update(_model_with_tags[m])
|
||||
all_models = models_with_tags
|
||||
|
||||
if filter:
|
||||
|
@ -142,7 +214,7 @@ def list_models(
|
|||
if len(include_models):
|
||||
models = models.union(include_models)
|
||||
else:
|
||||
models = set(all_models)
|
||||
models = all_models
|
||||
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
|
@ -173,6 +245,11 @@ def list_pretrained(
|
|||
)
|
||||
|
||||
|
||||
def get_deprecated_models(module: str = '') -> Dict[str, str]:
|
||||
all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
|
||||
return deepcopy(all_deprecated)
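A hedged sketch of the listing behaviour after this change: deprecated aliases are dropped from listings, tag-qualified names appear when pretrained/include_tags is set, and get_deprecated_models (internal, not in __all__) exposes the alias map.

import timm
from timm.models._registry import get_deprecated_models

print(timm.list_models('convnext_*')[:3])                        # architecture names only
print(timm.list_models('convnext_tiny*', pretrained=True)[:3])   # arch.tag names with weights
print(list(get_deprecated_models('convnext').items())[:2])       # old flat name -> new arch.tag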
|
||||
|
||||
|
||||
def is_model(model_name: str) -> bool:
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
|
|
|
@ -63,8 +63,7 @@ from torch.utils.checkpoint import checkpoint
|
|||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from .vision_transformer import checkpoint_filter_fn
|
||||
|
||||
__all__ = ['Beit']
|
||||
|
|
|
@ -50,8 +50,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
|
|||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -406,7 +405,7 @@ class ConvNeXt(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
||||
self.head.reset(num_classes, global_pool=global_pool)
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
|
@ -415,7 +414,7 @@ class ConvNeXt(nn.Module):
|
|||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
return self.head(x, pre_logits=pre_logits)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -519,6 +518,13 @@ def _cfgv2(url='', **kwargs):
|
|||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
# timm specific variants
|
||||
'convnext_tiny.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_atto.d2_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -558,12 +564,6 @@ default_cfgs = generate_default_cfgs({
|
|||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_tiny.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.in12k_ft_in1k_384': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
|
@ -582,25 +582,6 @@ default_cfgs = generate_default_cfgs({
|
|||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, num_classes=11821),
|
||||
|
||||
'convnext_tiny.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_base.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_large.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_xlarge.untrained': _cfg(),
|
||||
'convnext_xxlarge.untrained': _cfg(),
|
||||
|
||||
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -622,6 +603,23 @@ default_cfgs = generate_default_cfgs({
|
|||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_base.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_large.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -1038,3 +1036,22 @@ def convnextv2_huge(pretrained=False, **kwargs):
|
|||
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
|
||||
'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
|
||||
'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
|
||||
'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
|
||||
'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
|
||||
'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
|
||||
'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
|
||||
'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
|
||||
'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
|
||||
'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
|
||||
'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
|
||||
'convnext_small_in22k': 'convnext_small.fb_in22k',
|
||||
'convnext_base_in22k': 'convnext_base.fb_in22k',
|
||||
'convnext_large_in22k': 'convnext_large.fb_in22k',
|
||||
'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
|
||||
})
|
||||
|
|
|
@ -11,8 +11,6 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|||
# Copyright (c) 2022 Mingyu Ding
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the MIT license
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
|
@ -22,13 +20,12 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer
|
||||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['DaViT']
|
||||
|
||||
|
|
|
@ -359,10 +359,9 @@ class DLA(nn.Module):
|
|||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.fc(x)
|
||||
return self.flatten(x)
|
||||
x = self.fc(x)
|
||||
return self.flatten(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
|
|
@ -298,10 +298,9 @@ class DPN(nn.Module):
|
|||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.classifier(x)
|
||||
return self.flatten(x)
|
||||
x = self.classifier(x)
|
||||
return self.flatten(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
|
|
@ -21,8 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
|
|
@ -26,8 +26,7 @@ from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_nor
|
|||
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
|
||||
EfficientFormer_width = {
|
||||
|
|
|
@ -51,8 +51,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
|
|||
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from ._features import FeatureInfo, FeatureHooks
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['EfficientNet', 'EfficientNetFeatures']
|
||||
|
||||
|
@ -1064,42 +1063,46 @@ default_cfgs = generate_default_cfgs({
|
|||
'efficientnetv2_xl.untrained': _cfg(
|
||||
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'tf_efficientnet_b0.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 224, 224)),
|
||||
'tf_efficientnet_b1.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
|
||||
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'tf_efficientnet_b2.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
|
||||
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'tf_efficientnet_b3.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
|
||||
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'tf_efficientnet_b4.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
|
||||
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
'tf_efficientnet_b5.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
|
||||
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
|
||||
'tf_efficientnet_b6.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
|
||||
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
||||
'tf_efficientnet_b7.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
|
||||
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
||||
'tf_efficientnet_b8.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
|
||||
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
||||
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||
|
||||
'tf_efficientnet_b0.ap_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
|
||||
|
@ -1146,46 +1149,42 @@ default_cfgs = generate_default_cfgs({
|
|||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
|
||||
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
|
||||
'tf_efficientnet_b0.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 224, 224)),
|
||||
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
|
||||
'tf_efficientnet_b1.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
|
||||
'tf_efficientnet_b2.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
|
||||
'tf_efficientnet_b3.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
|
||||
'tf_efficientnet_b4.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
|
||||
'tf_efficientnet_b5.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
|
||||
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
|
||||
'tf_efficientnet_b6.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
||||
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
|
||||
'tf_efficientnet_b7.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
||||
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
|
||||
'tf_efficientnet_b8.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
||||
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
|
||||
'tf_efficientnet_es.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
||||
|
@ -1248,22 +1247,6 @@ default_cfgs = generate_default_cfgs({
|
|||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
|
||||
|
||||
'tf_efficientnetv2_s.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
||||
'tf_efficientnetv2_m.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'tf_efficientnetv2_l.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'tf_efficientnetv2_s.in21k_ft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -1285,6 +1268,22 @@ default_cfgs = generate_default_cfgs({
|
|||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'tf_efficientnetv2_s.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
||||
'tf_efficientnetv2_m.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'tf_efficientnetv2_l.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'tf_efficientnetv2_s.in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -2289,3 +2288,34 @@ def tinynet_d(pretrained=False, **kwargs):
|
|||
def tinynet_e(pretrained=False, **kwargs):
|
||||
model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
|
||||
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
|
||||
'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k',
|
||||
'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k',
|
||||
'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k',
|
||||
'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k',
|
||||
'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k',
|
||||
'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k',
|
||||
'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k',
|
||||
'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k',
|
||||
'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k',
|
||||
'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k',
|
||||
'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k',
|
||||
'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k',
|
||||
'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k',
|
||||
'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k',
|
||||
'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k',
|
||||
'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475',
|
||||
'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k',
|
||||
'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_xl_in21ft1k': 'tf_efficientnetv2_xl.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
|
||||
'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
|
||||
'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
|
||||
'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
|
||||
})
|
||||
|
|
|
@ -0,0 +1,643 @@
|
|||
""" FocalNet
|
||||
|
||||
As described in `Focal Modulation Networks` - https://arxiv.org/abs/2203.11926
|
||||
|
||||
Significant modifications and refactoring from the original impl at https://github.com/microsoft/FocalNet
|
||||
|
||||
This impl is/has:
|
||||
* fully convolutional, NCHW tensor layout throughout, seemed to have minimal performance impact but more flexible
|
||||
* re-ordered downsample / layer so that striding always at beginning of layer (stage)
|
||||
* no input size constraints or input resolution/H/W tracking through the model
|
||||
* torchscript fixed and a number of quirks cleaned up
|
||||
* feature extraction support via `features_only=True`
|
||||
"""
|
||||
# --------------------------------------------------------
|
||||
# FocalNets -- Focal Modulation Networks
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Jianwei Yang (jianwyan@microsoft.com)
|
||||
# --------------------------------------------------------
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['FocalNet']
|
||||
|
||||
|
||||
class FocalModulation(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
focal_window,
|
||||
focal_level: int,
|
||||
focal_factor: int = 2,
|
||||
bias: bool = True,
|
||||
use_post_norm: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: Callable = LayerNorm2d,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.focal_window = focal_window
|
||||
self.focal_level = focal_level
|
||||
self.focal_factor = focal_factor
|
||||
self.use_post_norm = use_post_norm
|
||||
self.normalize_modulator = normalize_modulator
|
||||
self.input_split = [dim, dim, self.focal_level + 1]
|
||||
|
||||
self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
|
||||
self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
||||
|
||||
self.act = nn.GELU()
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.focal_layers = nn.ModuleList()
|
||||
|
||||
self.kernel_sizes = []
|
||||
for k in range(self.focal_level):
|
||||
kernel_size = self.focal_factor * k + self.focal_window
|
||||
self.focal_layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False),
|
||||
nn.GELU(),
|
||||
))
|
||||
self.kernel_sizes.append(kernel_size)
|
||||
self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: input features with shape of (B, H, W, C)
|
||||
"""
|
||||
C = x.shape[1]
|
||||
|
||||
# pre linear projection
|
||||
x = self.f(x)
|
||||
q, ctx, gates = torch.split(x, self.input_split, 1)
|
||||
|
||||
# context aggregation
|
||||
ctx_all = 0
|
||||
for l, focal_layer in enumerate(self.focal_layers):
|
||||
ctx = focal_layer(ctx)
|
||||
ctx_all = ctx_all + ctx * gates[:, l:l + 1]
|
||||
ctx_global = self.act(ctx.mean((2, 3), keepdim=True))
|
||||
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
|
||||
|
||||
# normalize context
|
||||
if self.normalize_modulator:
|
||||
ctx_all = ctx_all / (self.focal_level + 1)
|
||||
|
||||
# focal modulation
|
||||
x_out = q * self.h(ctx_all)
|
||||
x_out = self.norm(x_out)
|
||||
|
||||
# post linear projection
|
||||
x_out = self.proj(x_out)
|
||||
x_out = self.proj_drop(x_out)
|
||||
return x_out
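A minimal shape check for the module above (import path assumed for the new focalnet file): focal modulation is fully convolutional, so any input H/W works and the spatial shape is preserved.

import torch
from timm.models.focalnet import FocalModulation

m = FocalModulation(dim=32, focal_window=3, focal_level=2)
x = torch.randn(2, 32, 14, 14)     # NCHW, matching the layout used throughout this impl
assert m(x).shape == x.shape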
|
||||
|
||||
|
||||
class LayerScale2d(nn.Module):
|
||||
def __init__(self, dim, init_values=1e-5, inplace=False):
|
||||
super().__init__()
|
||||
self.inplace = inplace
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
gamma = self.gamma.view(1, -1, 1, 1)
|
||||
return x.mul_(gamma) if self.inplace else x * gamma
|
||||
|
||||
|
||||
class FocalNetBlock(nn.Module):
|
||||
""" Focal Modulation Network Block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
mlp_ratio: float = 4.,
|
||||
focal_level: int = 1,
|
||||
focal_window: int = 3,
|
||||
use_post_norm: bool = False,
|
||||
use_post_norm_in_modulation: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
layerscale_value: float = 1e-4,
|
||||
proj_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = LayerNorm2d,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
focal_level: Number of focal levels.
|
||||
focal_window: Focal window size at first focal level.
|
||||
use_post_norm: Whether to use layer norm after modulation.
|
||||
use_post_norm_in_modulation: Whether to use layer norm in modulation.
|
||||
layerscale_value: Initial layerscale value.
|
||||
proj_drop: Dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
self.focal_window = focal_window
|
||||
self.focal_level = focal_level
|
||||
self.use_post_norm = use_post_norm
|
||||
|
||||
self.norm1 = norm_layer(dim) if not use_post_norm else nn.Identity()
|
||||
self.modulation = FocalModulation(
|
||||
dim,
|
||||
focal_window=focal_window,
|
||||
focal_level=self.focal_level,
|
||||
use_post_norm=use_post_norm_in_modulation,
|
||||
normalize_modulator=normalize_modulator,
|
||||
proj_drop=proj_drop,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.norm1_post = norm_layer(dim) if use_post_norm else nn.Identity()
|
||||
self.ls1 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim) if not use_post_norm else nn.Identity()
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
use_conv=True,
|
||||
)
|
||||
self.norm2_post = norm_layer(dim) if use_post_norm else nn.Identity()
|
||||
self.ls2 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
|
||||
# Focal Modulation
|
||||
x = self.norm1(x)
|
||||
x = self.modulation(x)
|
||||
x = self.norm1_post(x)
|
||||
x = shortcut + self.drop_path1(self.ls1(x))
|
||||
|
||||
# FFN
|
||||
x = x + self.drop_path2(self.ls2(self.norm2_post(self.mlp(self.norm2(x)))))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FocalNetStage(nn.Module):
|
||||
""" A basic Focal Transformer layer for one stage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
depth: int,
|
||||
mlp_ratio: float = 4.,
|
||||
downsample: bool = True,
|
||||
focal_level: int = 1,
|
||||
focal_window: int = 1,
|
||||
use_overlap_down: bool = False,
|
||||
use_post_norm: bool = False,
|
||||
use_post_norm_in_modulation: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
layerscale_value: float = 1e-4,
|
||||
proj_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
norm_layer: Callable = LayerNorm2d,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
out_dim: Number of output channels.
|
||||
depth: Number of blocks.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
downsample: Downsample layer at start of the layer.
|
||||
focal_level: Number of focal levels
|
||||
focal_window: Focal window size at first focal level
|
||||
use_overlap_down: Use overlapping convolution in downsample layer.
|
||||
use_post_norm: Whether to use layer norm after modulation.
|
||||
use_post_norm_in_modulation: Whether to use layer norm in modulation.
|
||||
layerscale_value: Initial layerscale value
|
||||
proj_drop: Dropout rate for projections.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.grad_checkpointing = False
|
||||
|
||||
if downsample:
|
||||
self.downsample = Downsample(
|
||||
in_chs=dim,
|
||||
out_chs=out_dim,
|
||||
stride=2,
|
||||
overlap=use_overlap_down,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
FocalNetBlock(
|
||||
dim=out_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
focal_level=focal_level,
|
||||
focal_window=focal_window,
|
||||
use_post_norm=use_post_norm,
|
||||
use_post_norm_in_modulation=use_post_norm_in_modulation,
|
||||
normalize_modulator=normalize_modulator,
|
||||
layerscale_value=layerscale_value,
|
||||
proj_drop=proj_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
for blk in self.blocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chs: int,
|
||||
out_chs: int,
|
||||
stride: int = 4,
|
||||
overlap: bool = False,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
in_chs: Number of input image channels.
|
||||
out_chs: Number of linear projection output channels.
|
||||
stride: Downsample stride.
|
||||
overlap: Use overlapping convolutions if True.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
padding = 0
|
||||
kernel_size = stride
|
||||
if overlap:
|
||||
assert stride in (2, 4)
|
||||
if stride == 4:
|
||||
kernel_size, padding = 7, 2
|
||||
elif stride == 2:
|
||||
kernel_size, padding = 3, 1
|
||||
self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.norm = norm_layer(out_chs) if norm_layer is not None else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class FocalNet(nn.Module):
|
||||
"""" Focal Modulation Networks (FocalNets)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 96,
|
||||
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
||||
mlp_ratio: float = 4.,
|
||||
focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
|
||||
focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
|
||||
use_overlap_down: bool = False,
|
||||
use_post_norm: bool = False,
|
||||
use_post_norm_in_modulation: bool = False,
|
||||
normalize_modulator: bool = False,
|
||||
head_hidden_size: Optional[int] = None,
|
||||
head_init_scale: float = 1.0,
|
||||
layerscale_value: Optional[float] = None,
|
||||
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
|
||||
norm_layer: Callable = partial(LayerNorm2d, eps=1e-5),
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_chans: Number of input image channels.
|
||||
num_classes: Number of classes for classification head.
|
||||
embed_dim: Patch embedding dimension.
|
||||
depths: Depth of each FocalNet stage.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
focal_levels: How many focal levels at all stages. Note that this excludes the finest-grain level.
|
||||
focal_windows: The focal window size at all stages.
|
||||
use_overlap_down: Whether to use convolutional embedding.
|
||||
use_post_norm: Whether to use layernorm after modulation (it helps stabilize training of large models).
|
||||
layerscale_value: Value for layer scale.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = len(depths)
|
||||
embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.embed_dim = embed_dim
|
||||
self.num_features = embed_dim[-1]
|
||||
self.feature_info = []
|
||||
|
||||
self.stem = Downsample(
|
||||
in_chs=in_chans,
|
||||
out_chs=embed_dim[0],
|
||||
overlap=use_overlap_down,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
in_dim = embed_dim[0]
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
layers = []
|
||||
for i_layer in range(self.num_layers):
|
||||
out_dim = embed_dim[i_layer]
|
||||
layer = FocalNetStage(
|
||||
dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
depth=depths[i_layer],
|
||||
mlp_ratio=mlp_ratio,
|
||||
downsample=i_layer > 0,
|
||||
focal_level=focal_levels[i_layer],
|
||||
focal_window=focal_windows[i_layer],
|
||||
use_overlap_down=use_overlap_down,
|
||||
use_post_norm=use_post_norm,
|
||||
use_post_norm_in_modulation=use_post_norm_in_modulation,
|
||||
normalize_modulator=normalize_modulator,
|
||||
layerscale_value=layerscale_value,
|
||||
proj_drop=proj_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
in_dim = out_dim
|
||||
layers += [layer]
|
||||
self.feature_info += [dict(num_chs=out_dim, reduction=4 * 2 ** i_layer, module=f'layers.{i_layer}')]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
if head_hidden_size:
|
||||
self.norm = nn.Identity()
|
||||
self.head = NormMlpClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
hidden_size=head_hidden_size,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
else:
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate
|
||||
)
|
||||
|
||||
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {''}
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
self.grad_checkpointing = enable
|
||||
for l in self.layers:
|
||||
l.set_grad_checkpointing(enable=enable)
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.layers(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
return self.head(x, pre_logits=pre_logits)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
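# Hedged sketch of how the pieces above fit together; constructor args mirror the defaults,
# weights are random (no pretrained download assumed).
import torch
model = FocalNet(depths=(2, 2, 6, 2), embed_dim=96, focal_levels=(2, 2, 2, 2))
x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)   # NCHW feature map at 32x reduction: (2, 768, 7, 7)
logits = model(x)                   # (2, 1000)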
|
||||
|
||||
|
||||
def _init_weights(module, name=None, head_init_scale=1.0):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
trunc_normal_(module.weight, std=.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Linear):
|
||||
trunc_normal_(module.weight, std=.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
if name and 'head.fc' in name:
|
||||
module.weight.data.mul_(head_init_scale)
|
||||
module.bias.data.mul_(head_init_scale)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.proj', 'classifier': 'head.fc',
|
||||
'license': 'mit', **kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
"focalnet_tiny_srf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_small_srf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_base_srf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_tiny_lrf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_small_lrf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_base_lrf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
|
||||
"focalnet_large_fl3.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_large_fl4.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_xlarge_fl3.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_xlarge_fl4.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_huge_fl3.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=21842),
|
||||
"focalnet_huge_fl4.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
})
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model: FocalNet):
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
if 'stem.proj.weight' in state_dict:
|
||||
return state_dict
|
||||
import re
|
||||
out_dict = {}
|
||||
dest_dict = model.state_dict()
|
||||
for k, v in state_dict.items():
|
||||
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
||||
k = k.replace('patch_embed', 'stem')
|
||||
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
|
||||
if 'norm' in k and k not in dest_dict:
|
||||
k = re.sub(r'norm([0-9])', r'norm\1_post', k)
|
||||
k = k.replace('ln.', 'norm.')
|
||||
k = k.replace('head', 'head.fc')
|
||||
if k in dest_dict and dest_dict[k].numel() == v.numel() and dest_dict[k].shape != v.shape:
|
||||
v = v.reshape(dest_dict[k].shape)
|
||||
out_dict[k] = v
|
||||
return out_dict
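# Hedged illustration of the key remapping above on a few representative original-checkpoint
# names (placeholder names only, no real weights involved).
import re
for k in ['patch_embed.proj.weight', 'layers.0.downsample.proj.weight', 'layers.2.blocks.0.gamma_1']:
    k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
    k = k.replace('patch_embed', 'stem')
    k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
    print(k)
# -> stem.proj.weight, layers.1.downsample.proj.weight, layers.2.blocks.0.ls1.gamma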
|
||||
|
||||
|
||||
def _create_focalnet(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
FocalNet, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
return model
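# Hedged usage sketch: with feature_cfg wired up above, the standard timm features_only
# path should yield one NCHW map per stage (strides 4/8/16/32 for the tiny config).
import timm, torch
fx = timm.create_model('focalnet_tiny_srf', features_only=True, out_indices=(0, 1, 2, 3))
outs = fx(torch.randn(1, 3, 224, 224))
print([tuple(o.shape) for o in outs])  # roughly 96x56x56, 192x28x28, 384x14x14, 768x7x7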
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_tiny_srf(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
|
||||
return _create_focalnet('focalnet_tiny_srf', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_small_srf(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
|
||||
return _create_focalnet('focalnet_small_srf', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_base_srf(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
|
||||
return _create_focalnet('focalnet_base_srf', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_tiny_lrf(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
|
||||
return _create_focalnet('focalnet_tiny_lrf', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_small_lrf(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
|
||||
return _create_focalnet('focalnet_small_lrf', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_base_lrf(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
|
||||
return _create_focalnet('focalnet_base_lrf', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
# FocalNet large+ models
|
||||
@register_model
|
||||
def focalnet_large_fl3(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
|
||||
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
|
||||
return _create_focalnet('focalnet_large_fl3', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_large_fl4(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4],
|
||||
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
|
||||
return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_xlarge_fl3(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
|
||||
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
|
||||
return _create_focalnet('focalnet_xlarge_fl3', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_xlarge_fl4(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4],
|
||||
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
|
||||
return _create_focalnet('focalnet_xlarge_fl4', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_huge_fl3(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], focal_windows=[3] * 4,
|
||||
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
|
||||
return _create_focalnet('focalnet_huge_fl3', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def focalnet_huge_fl4(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4],
|
||||
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
|
||||
return _create_focalnet('focalnet_huge_fl4', pretrained=pretrained, **model_kwargs)
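# Hedged sketch: every variant registered above becomes discoverable through the timm registry.
import timm
print(timm.list_models('focalnet*'))
# e.g. ['focalnet_base_lrf', 'focalnet_base_srf', ..., 'focalnet_xlarge_fl4']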
@ -29,12 +29,11 @@ import torch.utils.checkpoint as checkpoint
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
|
||||
get_attn, get_act_layer, get_norm_layer, _assert
|
||||
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply
|
||||
from ._registry import register_model
|
||||
from .vision_transformer_relpos import RelPosBias # FIXME move to common location
|
||||
|
||||
__all__ = ['GlobalContextVit']
|
||||
|
||||
|
@ -222,7 +221,7 @@ class WindowAttentionGlobal(nn.Module):
|
|||
q, k, v = qkv.unbind(0)
|
||||
q = q * self.scale
|
||||
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
|
||||
attn = self.rel_pos(attn)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
@ -24,7 +24,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
|||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
|
@ -35,9 +34,7 @@ from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
|
|||
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['Levit']
@ -52,8 +52,7 @@ from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
|
|||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
@ -22,8 +22,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
|
|||
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from ._features import FeatureInfo, FeatureHooks
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['MobileNetV3', 'MobileNetV3Features']
@ -145,13 +144,12 @@ class MobileNetV3(nn.Module):
|
|||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
x = self.flatten(x)
|
||||
if pre_logits:
|
||||
return x.flatten(1)
|
||||
else:
|
||||
x = self.flatten(x)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
return x
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
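# Hedged sketch of the reworked head path above: with num_classes=0 the classifier is an
# Identity, so the model returns the pooled conv_head features instead of logits.
import timm, torch
m = timm.create_model('mobilenetv3_large_100', num_classes=0)
feats = m(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 1280])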
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
@ -797,3 +795,9 @@ def lcnet_150(pretrained=False, **kwargs):
|
|||
""" PP-LCNet 1.5"""
|
||||
model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
|
||||
'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
|
||||
})
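# Hedged sketch: the deprecation mapping above should let the old untagged names keep
# resolving (with a deprecation warning) to the new tagged weights.
import timm
m = timm.create_model('mobilenetv3_large_100_miil', pretrained=False)
# roughly equivalent to: timm.create_model('mobilenetv3_large_100.miil_in21k_ft_in1k', pretrained=False)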
@ -17,86 +17,24 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
|||
# --------------------------------------------------------
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq, named_apply
|
||||
from ._registry import register_model
|
||||
from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
from .vision_transformer import get_init_weights_vit
|
||||
|
||||
__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swin_base_patch4_window12_384': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_base_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_large_patch4_window12_384': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
|
||||
'swin_large_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_small_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_tiny_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_base_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_base_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_s3_tiny_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
|
||||
),
|
||||
'swin_s3_small_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
|
||||
),
|
||||
'swin_s3_base_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
|
||||
)
|
||||
}
|
||||
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
def window_partition(x, window_size: int):
|
||||
|
@ -132,7 +70,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
|
|||
return x
|
||||
|
||||
|
||||
def get_relative_position_index(win_h, win_w):
|
||||
def get_relative_position_index(win_h: int, win_w: int):
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
|
@ -145,21 +83,30 @@ def get_relative_position_index(win_h, win_w):
|
|||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both shifted and non-shifted windows.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
head_dim (int): Number of channels per head (dim // num_heads if not set)
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports shifted and non-shifted windows.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
qkv_bias: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Number of channels per head (dim // num_heads if not set)
|
||||
window_size: The height and width of the window.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
attn_drop: Dropout ratio of attention weight.
|
||||
proj_drop: Dropout ratio of output.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = to_2tuple(window_size) # Wh, Ww
|
||||
|
@ -198,7 +145,7 @@ class WindowAttention(nn.Module):
|
|||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
@ -206,7 +153,7 @@ class WindowAttention(nn.Module):
|
|||
|
||||
if mask is not None:
|
||||
num_win = mask.shape[0]
|
||||
attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
|
@ -221,28 +168,41 @@ class WindowAttention(nn.Module):
|
|||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
r""" Swin Transformer Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
window_size (int): Window size.
|
||||
num_heads (int): Number of attention heads.
|
||||
head_dim (int): Enforce the number of channels per head
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
""" Swin Transformer Block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim: int,
|
||||
input_resolution: _int_or_tuple_2_t,
|
||||
num_heads: int = 4,
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
shift_size: int = 0,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: Callable = nn.GELU,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
window_size: Window size.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Enforce the number of channels per head
|
||||
shift_size: Shift size for SW-MSA.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
@ -257,12 +217,23 @@ class SwinTransformerBlock(nn.Module):
|
|||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
|
||||
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
|
@ -285,17 +256,15 @@ class SwinTransformerBlock(nn.Module):
|
|||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
self.register_buffer("attn_mask", attn_mask)
|
||||
|
||||
def forward(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
B, H, W, C = x.shape
|
||||
_assert(H == self.input_resolution[0], "input feature has wrong size")
|
||||
_assert(W == self.input_resolution[1], "input feature has wrong size")
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
|
@ -319,176 +288,221 @@ class SwinTransformerBlock(nn.Module):
|
|||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
x = x.reshape(B, -1, C)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
x = x.reshape(B, H, W, C)
|
||||
return x
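# Hedged sketch of the window helpers the block above relies on: partitioning an NHWC tensor
# into non-overlapping windows and reversing it should round-trip exactly (names assume the
# window_partition/window_reverse functions defined earlier in this file).
import torch
x = torch.randn(2, 56, 56, 96)                      # B, H, W, C
wins = window_partition(x, window_size=7)           # (num_windows * B, 7, 7, C)
y = window_reverse(wins, window_size=7, H=56, W=56)
assert torch.equal(x, y)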
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
r""" Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int]): Resolution of input feature.
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
""" Patch Merging Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
out_dim: Optional[int] = None,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
out_dim: Number of output channels (or 2 * dim if None)
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim or 2 * dim
|
||||
self.norm = norm_layer(4 * dim)
|
||||
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, H*W, C
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
_assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % 2 == 0, f"x height ({H}) is not even.")
|
||||
_assert(W % 2 == 0, f"x width ({W}) is not even.")
|
||||
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
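# Hedged shape sketch for the reworked NHWC PatchMerging above: spatial dims halve and
# channels go from dim to out_dim (2 * dim by default).
import torch
merge = PatchMerging(dim=96)                  # out_dim defaults to 192
y = merge(torch.randn(2, 56, 56, 96))         # B, H, W, C in
print(y.shape)                                # expected: torch.Size([2, 28, 28, 192])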
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
class SwinTransformerStage(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
head_dim (int): Channels per head (dim // num_heads if not set)
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
|
||||
|
||||
self,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
input_resolution: Tuple[int, int],
|
||||
depth: int,
|
||||
downsample: bool = True,
|
||||
num_heads: int = 4,
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: Union[List[float], float] = 0.,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
depth: Number of blocks.
|
||||
downsample: Downsample layer at the end of the layer.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Channels per head (dim // num_heads if not set)
|
||||
window_size: Local window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
|
||||
self.depth = depth
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# patch merging layer
|
||||
if downsample:
|
||||
self.downsample = PatchMerging(
|
||||
dim=dim,
|
||||
out_dim=out_dim,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
else:
|
||||
assert dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.Sequential(*[
|
||||
SwinTransformerBlock(
|
||||
dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim,
|
||||
window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
|
||||
dim=out_dim,
|
||||
input_resolution=self.output_resolution,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
return x
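# Hedged sketch: a stage built with downsample=True halves the NHWC grid via PatchMerging
# before running its blocks (checkpointed when grad_checkpointing is enabled).
import torch
stage = SwinTransformerStage(
    dim=96, out_dim=192, input_resolution=(56, 56), depth=2, downsample=True, num_heads=6)
y = stage(torch.randn(1, 56, 56, 96))
print(y.shape)  # expected: torch.Size([1, 28, 28, 192])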
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
r""" Swin Transformer
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
""" Swin Transformer
|
||||
|
||||
Args:
|
||||
img_size (int | tuple(int)): Input image size. Default 224
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
num_classes (int): Number of classes for classification head. Default: 1000
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
depths (tuple(int)): Depth of each Swin Transformer layer.
|
||||
num_heads (tuple(int)): Number of attention heads in different layers.
|
||||
head_dim (int, tuple(int)): Channels per head (dim // num_heads if not set).
|
||||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
||||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
||||
https://arxiv.org/pdf/2103.14030
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None,
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs):
|
||||
self,
|
||||
img_size: _int_or_tuple_2_t = 224,
|
||||
patch_size: int = 4,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 96,
|
||||
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
||||
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
|
||||
head_dim: Optional[int] = None,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
||||
weight_init: str = '',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of input image channels.
|
||||
num_classes: Number of classes for classification head.
|
||||
embed_dim: Patch embedding dimension.
|
||||
depths: Depth of each Swin Transformer layer.
|
||||
num_heads: Number of attention heads in different layers.
|
||||
head_dim: Dimension of self-attention heads.
|
||||
window_size: Window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop_rate: Dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
assert global_pool in ('', 'avg')
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.output_fmt = 'NHWC'
|
||||
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||
self.feature_info = []
|
||||
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if patch_norm else None)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim[0],
|
||||
norm_layer=norm_layer,
|
||||
output_fmt='NHWC',
|
||||
)
|
||||
self.patch_grid = self.patch_embed.grid_size
|
||||
|
||||
# absolute position embedding
|
||||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# build layers
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
embed_out_dim = embed_dim[1:] + [None]
|
||||
head_dim = to_ntuple(self.num_layers)(head_dim)
|
||||
window_size = to_ntuple(self.num_layers)(window_size)
|
||||
mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
layers = []
|
||||
in_dim = embed_dim[0]
|
||||
scale = 1
|
||||
for i in range(self.num_layers):
|
||||
layers += [BasicLayer(
|
||||
dim=embed_dim[i],
|
||||
out_dim=embed_out_dim[i],
|
||||
input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)),
|
||||
out_dim = embed_dim[i]
|
||||
layers += [SwinTransformerStage(
|
||||
dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
input_resolution=(
|
||||
self.patch_grid[0] // scale,
|
||||
self.patch_grid[1] // scale
|
||||
),
|
||||
depth=depths[i],
|
||||
downsample=i > 0,
|
||||
num_heads=num_heads[i],
|
||||
head_dim=head_dim[i],
|
||||
window_size=window_size[i],
|
||||
|
@ -496,29 +510,35 @@ class SwinTransformer(nn.Module):
|
|||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i < self.num_layers - 1) else None
|
||||
)]
|
||||
in_dim = out_dim
|
||||
if i > 0:
|
||||
scale *= 2
|
||||
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
input_fmt=self.output_fmt,
|
||||
)
|
||||
if weight_init != 'skip':
|
||||
self.init_weights(weight_init)
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_weights(self, mode=''):
|
||||
assert mode in ('jax', 'jax_nlhb', 'moco', '')
|
||||
if self.absolute_pos_embed is not None:
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
|
||||
named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
nwd = {'absolute_pos_embed'}
|
||||
nwd = set()
|
||||
for n, _ in self.named_parameters():
|
||||
if 'relative_position_bias_table' in n:
|
||||
nwd.add(n)
|
||||
|
@ -527,7 +547,7 @@ class SwinTransformer(nn.Module):
|
|||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
return dict(
|
||||
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
|
||||
stem=r'^patch_embed', # stem and embed
|
||||
blocks=r'^layers\.(\d+)' if coarse else [
|
||||
(r'^layers\.(\d+).downsample', (0,)),
|
||||
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
|
||||
|
@ -542,28 +562,20 @@ class SwinTransformer(nn.Module):
|
|||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'avg')
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.absolute_pos_embed is not None:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.pos_drop(x)
|
||||
x = self.layers(x)
|
||||
x = self.norm(x) # B L C
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
return x if pre_logits else self.head(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -571,58 +583,118 @@ class SwinTransformer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
if 'head.fc.weight' in state_dict:
|
||||
return state_dict
|
||||
import re
|
||||
out_dict = {}
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
for k, v in state_dict.items():
|
||||
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
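# Hedged illustration of the remap above: original checkpoints attach patch merging to the
# end of stage i, the refactored model attaches it to the start of stage i + 1, and the
# classifier moves under head.fc.
import re
for k in ['layers.0.downsample.reduction.weight', 'head.weight']:
    k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
    k = k.replace('head.', 'head.fc.')
    print(k)
# -> layers.1.downsample.reduction.weight, head.fc.weight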
|
||||
|
||||
|
||||
def _create_swin_transformer(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformer, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
|
||||
return model
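# Hedged usage sketch: with flatten_sequential feature hooks configured above, the standard
# features_only path should return one feature map per selected stage (strides 4/8/16/32,
# channel counts following embed_dim doubling; the map layout follows the model's output_fmt).
import timm, torch
fx = timm.create_model('swin_tiny_patch4_window7_224', features_only=True)
outs = fx(torch.randn(1, 3, 224, 224))
print([tuple(o.shape) for o in outs])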
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
'license': 'mit', **kwargs
|
||||
}
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
|
||||
'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
|
||||
'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
|
||||
'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
|
||||
'swin_small_patch4_window7_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
|
||||
'swin_base_patch4_window7_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
|
||||
'swin_base_patch4_window12_384.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
# tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
|
||||
'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),
|
||||
|
||||
'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_small_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_base_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_base_patch4_window12_384.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
'swin_large_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_large_patch4_window12_384.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-S @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
'swin_s3_tiny_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
|
||||
'swin_s3_small_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
|
||||
'swin_s3_base_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
|
||||
})
|
||||
|
||||
|
||||
@register_model
@ -635,44 +707,53 @@ def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384, trained ImageNet-22k
|
||||
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-S @ 224x224
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
|
||||
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 224x224, trained ImageNet-22k
|
||||
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 224x224
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 384x384, trained ImageNet-22k
|
||||
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
|
||||
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 224x224, trained ImageNet-22k
|
||||
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 224x224
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 384x384
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_s3_tiny_224(pretrained=False, **kwargs):
|
||||
""" Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
|
||||
""" Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
|
||||
|
@ -682,7 +763,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def swin_s3_small_224(pretrained=False, **kwargs):
|
||||
""" Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
|
||||
""" Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
|
||||
|
@ -692,10 +773,17 @@ def swin_s3_small_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def swin_s3_base_224(pretrained=False, **kwargs):
|
||||
""" Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
|
||||
""" Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
|
||||
num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
|
||||
'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
|
||||
'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
|
||||
'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
|
||||
})
|
@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
|
|||
# Written by Ze Liu
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -21,76 +21,14 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_tiny_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_tiny_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_small_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_small_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_base_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_base_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
|
||||
'swinv2_base_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192)
|
||||
),
|
||||
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
),
|
||||
'swinv2_large_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192)
|
||||
),
|
||||
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
input_size=(3, 256, 256)
|
||||
),
|
||||
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
),
|
||||
}
|
||||
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
def window_partition(x, window_size: Tuple[int, int]):
|
||||
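The hunk cuts off before the body; for orientation, a minimal NHWC partition/reverse pair of the kind this function implements (a sketch, assuming H and W divide evenly by the window size):

import torch

def _window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, win_h, win_w, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size[0], window_size[1], C)

def _window_reverse(windows, window_size, H, W):
    # inverse of _window_partition, restoring (B, H, W, C)
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, H, W, C)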
|
@ -141,9 +79,15 @@ class WindowAttention(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||
pretrained_window_size=[0, 0]):
|
||||
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
pretrained_window_size=[0, 0],
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
|
@ -231,8 +175,8 @@ class WindowAttention(nn.Module):
|
|||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
num_win = mask.shape[0]
|
||||
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
|
@ -246,29 +190,42 @@ class WindowAttention(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
r""" Swin Transformer Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
pretrained_window_size (int): Window size in pretraining.
|
||||
class SwinTransformerV2Block(nn.Module):
|
||||
""" Swin Transformer Block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
|
||||
self,
|
||||
dim,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
num_heads: Number of attention heads.
|
||||
window_size: Window size.
|
||||
shift_size: Shift size for SW-MSA.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer.
|
||||
pretrained_window_size: Window size in pretraining.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = to_2tuple(input_resolution)
|
||||
|
@ -280,13 +237,23 @@ class SwinTransformerBlock(nn.Module):
|
|||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size))
|
||||
dim,
|
||||
window_size=to_2tuple(self.window_size),
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size),
|
||||
)
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
|
@ -322,10 +289,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
return tuple(window_size), tuple(shift_size)
|
||||
|
||||
def _attn(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
x = x.view(B, H, W, C)
|
||||
B, H, W, C = x.shape
|
||||
|
||||
# cyclic shift
|
||||
has_shift = any(self.shift_size)
|
||||
|
@ -350,113 +314,124 @@ class SwinTransformerBlock(nn.Module):
|
|||
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, H * W, C)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
B, H, W, C = x.shape
|
||||
x = x + self.drop_path1(self.norm1(self._attn(x)))
|
||||
x = x.reshape(B, -1, C)
|
||||
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
||||
x = x.reshape(B, H, W, C)
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
r""" Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int]): Resolution of input feature.
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
""" Patch Merging Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
||||
def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
out_dim (int): Number of output channels (or 2 * dim if None)
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(2 * dim)
|
||||
self.out_dim = out_dim or 2 * dim
|
||||
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
|
||||
self.norm = norm_layer(self.out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, H*W, C
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
_assert(L == H * W, "input feature has wrong size")
|
||||
_assert(H % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
_assert(W % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % 2 == 0, f"x height ({H}) is not even.")
|
||||
_assert(W % 2 == 0, f"x width ({W}) is not even.")
|
||||
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
|
||||
x = self.reduction(x)
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
||||
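Sanity check (a sketch with toy sizes): the new reshape/permute merge preserves the x0, x1, x2, x3 channel ordering of the slice-and-concat code it replaces.

import torch

B, H, W, C = 2, 8, 8, 4
x = torch.randn(B, H, W, C)
old = torch.cat([
    x[:, 0::2, 0::2, :],  # x0
    x[:, 1::2, 0::2, :],  # x1
    x[:, 0::2, 1::2, :],  # x2
    x[:, 1::2, 1::2, :],  # x3
], dim=-1)
new = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
print(torch.allclose(old, new))  # True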
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
pretrained_window_size (int): Local window size in pre-training.
|
||||
class SwinTransformerV2Stage(nn.Module):
|
||||
""" A Swin Transformer V2 Stage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, input_resolution, depth, num_heads, window_size,
|
||||
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
|
||||
norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0):
|
||||
|
||||
self,
|
||||
dim,
|
||||
out_dim,
|
||||
input_resolution,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
downsample=False,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0,
|
||||
output_nchw=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
depth: Number of blocks.
|
||||
num_heads: Number of attention heads.
|
||||
window_size: Local window size.
|
||||
downsample: Use downsample layer at start of the block.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
pretrained_window_size: Local window size in pretraining.
|
||||
output_nchw: Output tensors on NCHW format instead of NHWC.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
|
||||
self.depth = depth
|
||||
self.output_nchw = output_nchw
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# patch merging / downsample layer
|
||||
if downsample:
|
||||
self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer)
|
||||
else:
|
||||
assert dim == out_dim
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim, input_resolution=input_resolution,
|
||||
num_heads=num_heads, window_size=window_size,
|
||||
SwinTransformerV2Block(
|
||||
dim=out_dim,
|
||||
input_resolution=self.output_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop, attn_drop=attn_drop,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
pretrained_window_size=pretrained_window_size)
|
||||
pretrained_window_size=pretrained_window_size,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.downsample(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
def _init_respostnorm(self):
|
||||
|
@ -468,88 +443,113 @@ class BasicLayer(nn.Module):
|
|||
|
||||
|
||||
class SwinTransformerV2(nn.Module):
|
||||
r""" Swin Transformer V2
|
||||
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
|
||||
- https://arxiv.org/abs/2111.09883
|
||||
Args:
|
||||
img_size (int | tuple(int)): Input image size. Default 224
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
num_classes (int): Number of classes for classification head. Default: 1000
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
depths (tuple(int)): Depth of each Swin Transformer layer.
|
||||
num_heads (tuple(int)): Number of attention heads in different layers.
|
||||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
||||
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
|
||||
""" Swin Transformer V2
|
||||
|
||||
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
|
||||
- https://arxiv.org/abs/2111.09883
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
pretrained_window_sizes=(0, 0, 0, 0), **kwargs):
|
||||
self,
|
||||
img_size: _int_or_tuple_2_t = 224,
|
||||
patch_size: int = 4,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'avg',
|
||||
embed_dim: int = 96,
|
||||
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
||||
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of input image channels.
|
||||
num_classes: Number of classes for classification head.
|
||||
embed_dim: Patch embedding dimension.
|
||||
depths: Depth of each Swin Transformer stage (layer).
|
||||
num_heads: Number of attention heads in different layers.
|
||||
window_size: Window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop_rate: Dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
patch_norm: If True, add normalization after patch embedding.
|
||||
pretrained_window_sizes: Pretrained window sizes of each layer.
|
||||
output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
assert global_pool in ('', 'avg')
|
||||
self.global_pool = global_pool
|
||||
self.output_fmt = 'NHWC'
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||
self.feature_info = []
|
||||
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim[0],
|
||||
norm_layer=norm_layer,
|
||||
output_fmt='NHWC',
|
||||
)
|
||||
|
||||
# absolute position embedding
|
||||
if ape:
|
||||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
else:
|
||||
self.absolute_pos_embed = None
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2 ** i_layer),
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
layers = []
|
||||
in_dim = embed_dim[0]
|
||||
scale = 1
|
||||
for i in range(self.num_layers):
|
||||
out_dim = embed_dim[i]
|
||||
layers += [SwinTransformerV2Stage(
|
||||
dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
input_resolution=(
|
||||
self.patch_embed.grid_size[0] // (2 ** i_layer),
|
||||
self.patch_embed.grid_size[1] // (2 ** i_layer)),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
self.patch_embed.grid_size[0] // scale,
|
||||
self.patch_embed.grid_size[1] // scale),
|
||||
depth=depths[i],
|
||||
downsample=i > 0,
|
||||
num_heads=num_heads[i],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
pretrained_window_size=pretrained_window_sizes[i_layer]
|
||||
)
|
||||
self.layers.append(layer)
|
||||
pretrained_window_size=pretrained_window_sizes[i],
|
||||
)]
|
||||
in_dim = out_dim
|
||||
if i > 0:
|
||||
scale *= 2
|
||||
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
input_fmt=self.output_fmt,
|
||||
)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
for bly in self.layers:
|
||||
|
@ -563,7 +563,7 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
nod = {'absolute_pos_embed'}
|
||||
nod = set()
|
||||
for n, m in self.named_modules():
|
||||
if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
|
||||
nod.add(n)
|
||||
|
@ -587,31 +587,20 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
assert global_pool in ('', 'avg')
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.absolute_pos_embed is not None:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
x = self.norm(x) # B L C
|
||||
x = self.layers(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
return x if pre_logits else self.head(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
@ -620,25 +609,102 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
if 'head.fc.weight' in state_dict:
|
||||
return state_dict
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
import re
|
||||
for k, v in state_dict.items():
|
||||
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
|
||||
continue # skip buffers that should not be persistent
|
||||
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
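On representative keys the remapping above does the following; the stage index shifts by one because the downsample now sits at the start of the next stage, and the classifier moves under head.fc (keys below are illustrative):

import re

for k in ['layers.0.downsample.reduction.weight', 'head.weight']:
    k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
    k = k.replace('head.', 'head.fc.')
    print(k)
# layers.1.downsample.reduction.weight
# head.fc.weight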
|
||||
|
||||
def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformerV2, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs)
|
||||
return model
|
||||
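With feature_cfg in place, SwinV2 variants can be built as feature backbones; a sketch (pretrained=False keeps it offline, the channel/reduction values assume the tiny config):

import timm

model = timm.create_model(
    'swinv2_tiny_window8_256', pretrained=False,
    features_only=True, out_indices=(0, 1, 2, 3))
print(model.feature_info.channels())   # e.g. [96, 192, 384, 768]
print(model.feature_info.reduction())  # e.g. [4, 8, 16, 32]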
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
'license': 'mit', **kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
),
|
||||
'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
|
||||
),
|
||||
'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
),
|
||||
'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
|
||||
),
|
||||
|
||||
'swinv2_tiny_window8_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_tiny_window16_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_small_window8_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_small_window16_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_base_window8_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_base_window16_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
|
||||
),
|
||||
|
||||
'swinv2_base_window12_192.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
'swinv2_large_window12_192.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
})
|
||||
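Weights are now addressed as architecture.tag pairs; roughly how this looks from the user side (a sketch, tag names as defined above):

import timm

# Tagged names should appear when listing models with pretrained weights.
print(timm.list_models('swinv2_base_window12_192*', pretrained=True))
# Select a specific weight source by tag (pretrained=False here just builds the arch).
model = timm.create_model('swinv2_base_window12_192.ms_in22k', pretrained=False)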
|
||||
|
||||
@register_model
|
||||
def swinv2_tiny_window16_256(pretrained=False, **kwargs):
|
||||
"""
|
||||
|
@ -694,62 +760,72 @@ def swinv2_base_window8_256(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def swinv2_base_window12_192_22k(pretrained=False, **kwargs):
|
||||
def swinv2_base_window12_192(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_base_window12_192', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_base_window12to16_192to256(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_base_window12to16_192to256', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_base_window12to24_192to384(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_base_window12to24_192to384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_large_window12_192_22k(pretrained=False, **kwargs):
|
||||
def swinv2_large_window12_192(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_large_window12_192', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_large_window12to16_192to256(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_large_window12to16_192to256', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_large_window12to24_192to384(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_large_window12to24_192to384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
|
||||
'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
|
||||
'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
|
||||
'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
|
||||
'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
|
||||
'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
|
||||
})
|
||||
|
|
|
@ -37,71 +37,17 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, Mlp, to_2tuple, _assert
|
||||
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000,
|
||||
'input_size': (3, 224, 224),
|
||||
'pool_size': (7, 7),
|
||||
'crop_pct': 0.9,
|
||||
'interpolation': 'bicubic',
|
||||
'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj',
|
||||
'classifier': 'head',
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_cr_tiny_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_tiny_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_tiny_ns_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_small_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_224': _cfg(
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_base_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_ns_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_large_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_large_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_huge_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_huge_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_giant_384': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_giant_224': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
}
|
||||
|
||||
|
||||
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
|
||||
return x.permute(0, 2, 3, 1)
|
||||
|
@ -230,22 +176,14 @@ class WindowMultiHeadAttention(nn.Module):
|
|||
relative_position_bias = relative_position_bias.unsqueeze(0)
|
||||
return relative_position_bias
|
||||
|
||||
def _forward_sequential(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
"""
|
||||
# FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
|
||||
assert False, "not implemented"
|
||||
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
""" Forward pass.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
|
||||
mask (Optional[torch.Tensor]): Attention mask for the shift case
|
||||
|
||||
def _forward_batch(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs standard (non-sequential) scaled cosine self-attention.
|
||||
Returns:
|
||||
Output tensor of the shape [B * windows, N, C]
|
||||
"""
|
||||
Bw, L, C = x.shape
|
||||
|
||||
|
@ -272,22 +210,8 @@ class WindowMultiHeadAttention(nn.Module):
|
|||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
""" Forward pass.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
|
||||
mask (Optional[torch.Tensor]): Attention mask for the shift case
|
||||
|
||||
Returns:
|
||||
Output tensor of the shape [B * windows, N, C]
|
||||
"""
|
||||
if self.sequential_attn:
|
||||
return self._forward_sequential(x, mask)
|
||||
else:
|
||||
return self._forward_batch(x, mask)
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
class SwinTransformerV2CrBlock(nn.Module):
|
||||
r"""This class implements the Swin transformer block.
|
||||
|
||||
Args:
|
||||
|
@ -321,7 +245,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
sequential_attn: bool = False,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super(SwinTransformerBlock, self).__init__()
|
||||
super(SwinTransformerV2CrBlock, self).__init__()
|
||||
self.dim: int = dim
|
||||
self.feat_size: Tuple[int, int] = feat_size
|
||||
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
|
||||
|
@ -410,9 +334,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
self._make_attention_mask()
|
||||
|
||||
def _shifted_window_attn(self, x):
|
||||
H, W = self.feat_size
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, H, W, C)
|
||||
B, H, W, C = x.shape
|
||||
|
||||
# cyclic shift
|
||||
sh, sw = self.shift_size
|
||||
|
@ -441,7 +363,6 @@ class SwinTransformerBlock(nn.Module):
|
|||
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
|
||||
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
|
||||
|
||||
x = x.view(B, L, C)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -455,8 +376,12 @@ class SwinTransformerBlock(nn.Module):
|
|||
"""
|
||||
# post-norm branches (op -> norm -> drop)
|
||||
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
|
||||
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(B, -1, C)
|
||||
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
||||
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
|
||||
x = x.reshape(B, H, W, C)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -479,12 +404,10 @@ class PatchMerging(nn.Module):
|
|||
Returns:
|
||||
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
# unfold + BCHW -> BHWC together
|
||||
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
|
||||
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
|
||||
x = self.norm(x)
|
||||
x = bhwc_to_bchw(self.reduction(x))
|
||||
x = self.reduction(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -511,7 +434,7 @@ class PatchEmbed(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class SwinTransformerStage(nn.Module):
|
||||
class SwinTransformerV2CrStage(nn.Module):
|
||||
r"""This class implements a stage of the Swin transformer including multiple layers.
|
||||
|
||||
Args:
|
||||
|
@ -549,12 +472,16 @@ class SwinTransformerStage(nn.Module):
|
|||
extra_norm_stage: bool = False,
|
||||
sequential_attn: bool = False,
|
||||
) -> None:
|
||||
super(SwinTransformerStage, self).__init__()
|
||||
super(SwinTransformerV2CrStage, self).__init__()
|
||||
self.downscale: bool = downscale
|
||||
self.grad_checkpointing: bool = False
|
||||
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
|
||||
|
||||
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
|
||||
if downscale:
|
||||
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
|
||||
embed_dim = embed_dim * 2
|
||||
else:
|
||||
self.downsample = nn.Identity()
|
||||
|
||||
def _extra_norm(index):
|
||||
i = index + 1
|
||||
|
@ -562,9 +489,8 @@ class SwinTransformerStage(nn.Module):
|
|||
return True
|
||||
return i == depth if extra_norm_stage else False
|
||||
|
||||
embed_dim = embed_dim * 2 if downscale else embed_dim
|
||||
self.blocks = nn.Sequential(*[
|
||||
SwinTransformerBlock(
|
||||
SwinTransformerV2CrBlock(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
feat_size=self.feat_size,
|
||||
|
@ -602,18 +528,15 @@ class SwinTransformerStage(nn.Module):
|
|||
Returns:
|
||||
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
|
||||
"""
|
||||
x = bchw_to_bhwc(x)
|
||||
x = self.downsample(x)
|
||||
B, C, H, W = x.shape
|
||||
L = H * W
|
||||
|
||||
x = bchw_to_bhwc(x).reshape(B, L, C)
|
||||
for block in self.blocks:
|
||||
# Perform checkpointing if utilized
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint.checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
x = bhwc_to_bchw(x.reshape(B, H, W, -1))
|
||||
x = bhwc_to_bchw(x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -676,39 +599,54 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
self.img_size: Tuple[int, int] = img_size
|
||||
self.window_size: int = window_size
|
||||
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
|
||||
self.feature_info = []
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=norm_layer)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
|
||||
|
||||
drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
|
||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||
stages = []
|
||||
for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):
|
||||
stage_scale = 2 ** max(index - 1, 0)
|
||||
stages.append(
|
||||
SwinTransformerStage(
|
||||
embed_dim=embed_dim * stage_scale,
|
||||
depth=depth,
|
||||
downscale=index != 0,
|
||||
feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
drop_attn=attn_drop_rate,
|
||||
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
|
||||
extra_norm_period=extra_norm_period,
|
||||
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
|
||||
sequential_attn=sequential_attn,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
)
|
||||
in_dim = embed_dim
|
||||
in_scale = 1
|
||||
for stage_idx, (depth, num_heads) in enumerate(zip(depths, num_heads)):
|
||||
stages += [SwinTransformerV2CrStage(
|
||||
embed_dim=in_dim,
|
||||
depth=depth,
|
||||
downscale=stage_idx != 0,
|
||||
feat_size=(
|
||||
patch_grid_size[0] // in_scale,
|
||||
patch_grid_size[1] // in_scale
|
||||
),
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
drop_attn=attn_drop_rate,
|
||||
drop_path=dpr[stage_idx],
|
||||
extra_norm_period=extra_norm_period,
|
||||
extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths), # last stage ends w/ norm
|
||||
sequential_attn=sequential_attn,
|
||||
norm_layer=norm_layer,
|
||||
)]
|
||||
if stage_idx != 0:
|
||||
in_dim *= 2
|
||||
in_scale *= 2
|
||||
self.feature_info += [dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}')]
|
||||
self.stages = nn.Sequential(*stages)
|
||||
|
||||
self.global_pool: str = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
self.num_features,
|
||||
num_classes,
|
||||
pool_type=global_pool,
|
||||
drop_rate=drop_rate,
|
||||
)
|
||||
|
||||
# current weight init skips custom init and uses pytorch layer defaults, seems to work well
|
||||
# FIXME more experiments needed
|
||||
|
@ -765,7 +703,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
Returns:
|
||||
head (nn.Module): Current classification head
|
||||
"""
|
||||
return self.head
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
|
||||
"""Method results the classification head
|
||||
|
@ -774,10 +712,8 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
num_classes (int): Number of classes to be predicted
|
||||
global_pool (str): Unused
|
||||
"""
|
||||
self.num_classes: int = num_classes
|
||||
if global_pool is not None:
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
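A sketch of the new head behaviour: resetting the classifier swaps only the inner head.fc module, and get_classifier() returns that Linear layer.

import timm

model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)
model.reset_classifier(num_classes=10)
fc = model.get_classifier()
print(type(fc).__name__, fc.out_features)  # Linear 10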
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
|
@ -785,9 +721,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=(2, 3))
|
||||
return x if pre_logits else self.head(x)
|
||||
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
|
@ -814,30 +748,93 @@ def init_weights(module: nn.Module, name: str = ''):
|
|||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
if 'head.fc.weight' in state_dict:
|
||||
return state_dict
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
for k, v in state_dict.items():
|
||||
if 'tau' in k:
|
||||
# convert old tau based checkpoints -> logit_scale (inverse)
|
||||
v = torch.log(1 / v)
|
||||
k = k.replace('tau', 'logit_scale')
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
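The tau -> logit_scale conversion above in numbers (a sketch): older checkpoints store a temperature tau, current code keeps log(1 / tau) and exponentiates it back at attention time.

import torch

tau = torch.tensor([0.07])
logit_scale = torch.log(1 / tau)   # value written into converted checkpoints
print(torch.exp(logit_scale))      # tensor([14.2857]) == 1 / tau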
|
||||
|
||||
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
|
||||
out_indices = kwargs.pop('out_indices', default_out_indices)
|
||||
|
||||
model = build_model_with_cfg(
|
||||
SwinTransformerV2Cr, variant, pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000,
|
||||
'input_size': (3, 224, 224),
|
||||
'pool_size': (7, 7),
|
||||
'crop_pct': 0.9,
|
||||
'interpolation': 'bicubic',
|
||||
'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN,
|
||||
'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj',
|
||||
'classifier': 'head.fc',
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'swinv2_cr_tiny_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_tiny_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_tiny_ns_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_small_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_256.untrained': _cfg(
|
||||
url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
|
||||
'swinv2_cr_base_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_base_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_ns_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_large_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_large_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_huge_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_huge_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_giant_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_giant_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_cr_tiny_384(pretrained=False, **kwargs):
|
||||
"""Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
|
||||
|
@ -915,6 +912,19 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
|
|||
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_cr_small_ns_256(pretrained=False, **kwargs):
|
||||
"""Swin-S V2 CR @ 256x256, trained ImageNet-1k"""
|
||||
model_kwargs = dict(
|
||||
embed_dim=96,
|
||||
depths=(2, 2, 18, 2),
|
||||
num_heads=(3, 6, 12, 24),
|
||||
extra_norm_stage=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_256', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_cr_base_384(pretrained=False, **kwargs):
|
||||
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
|
||||
|
@ -961,8 +971,7 @@ def swinv2_cr_large_384(pretrained=False, **kwargs):
|
|||
num_heads=(6, 12, 24, 48),
|
||||
**kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs
|
||||
)
|
||||
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -41,9 +41,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_,
|
|||
resample_abs_pos_embed, RmsNorm
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
|
|
@ -20,8 +20,7 @@ import torch.nn as nn
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from .resnet import resnet26d, resnet50d
|
||||
from .resnetv2 import ResNetV2, create_resnetv2_stem
|
||||
from .vision_transformer import _create_vision_transformer
|
||||
|
|
|
@ -17,8 +17,7 @@ from torch.utils.checkpoint import checkpoint
|
|||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = '0.8.15dev0'
|
||||
__version__ = '0.8.16dev0'
|
||||
|
|
|
@ -255,8 +255,7 @@ def validate(args):
|
|||
|
||||
if args.valid_labels:
|
||||
with open(args.valid_labels, 'r') as f:
|
||||
valid_labels = {int(line.rstrip()) for line in f}
|
||||
valid_labels = [i in valid_labels for i in range(args.num_classes)]
|
||||
valid_labels = [int(line.rstrip()) for line in f]
|
||||
else:
|
||||
valid_labels = None
|
||||
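The new code keeps valid_labels as integer class indices instead of a boolean mask over num_classes; column selection downstream is equivalent when the label file is in ascending order, as in this sketch:

import torch

logits = torch.randn(2, 10)
valid_labels = [1, 3, 7]                                   # new style: int indices
valid_set = set(valid_labels)
mask = torch.tensor([i in valid_set for i in range(10)])   # old style: bool mask
print(torch.equal(logits[:, valid_labels], logits[:, mask]))  # True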
|
||||
|
|