Merge pull request #1628 from huggingface/focalnet_and_swin_refactor

Add FocalNet arch, refactor Swin V1/V2 for better feature extraction and HF hub multi-weight support
Ross Wightman 2023-03-18 20:09:36 -07:00 committed by GitHub
commit 0d5c5c39fc
45 changed files with 57846 additions and 1146 deletions


@ -27,7 +27,8 @@ except ImportError:
import timm
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models._features_fx import _leaf_modules, _autowrap_functions
from timm.layers import Format, get_spatial_dim, get_channel_dim
from timm.models import get_notrace_modules, get_notrace_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@ -37,10 +38,10 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*'
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*',
]
NUM_NON_STD = len(NON_STD_FILTERS)
@ -52,7 +53,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
NON_STD_EXCLUDE_FILTERS = ['vit_gi*']
@ -145,7 +146,8 @@ def test_model_backward(model_name, batch_size):
@pytest.mark.cfg
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
"""Run a single forward pass with each model"""
@ -156,6 +158,10 @@ def test_model_default_cfgs(model_name, batch_size):
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
output_fmt = getattr(model, 'output_fmt', 'NCHW')
spatial_axis = get_spatial_dim(output_fmt)
assert len(spatial_axis) == 2 # TODO add 1D sequence support
feat_axis = get_channel_dim(output_fmt)
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
@ -165,13 +171,16 @@ def test_model_default_cfgs(model_name, batch_size):
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
assert outputs.shape[feat_axis] == model.num_features
# test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
assert outputs.shape[-1] == model.num_features
assert outputs.shape[1] == model.num_features
# test model forward without pooling and classifier
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
@ -179,7 +188,7 @@ def test_model_default_cfgs(model_name, batch_size):
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
@ -187,7 +196,7 @@ def test_model_default_cfgs(model_name, batch_size):
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
# check classifier name matches default_cfg
if cfg.get('num_classes', None):
@ -330,176 +339,181 @@ def test_model_forward_features(model_name, batch_size):
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
expected_channels = model.feature_info.channels()
expected_reduction = model.feature_info.reduction()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
output_fmt = getattr(model, 'output_fmt', 'NCHW')
feat_axis = get_channel_dim(output_fmt)
spatial_axis = get_spatial_dim(output_fmt)
import math
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
for e, o in zip(expected_channels, outputs):
assert e == o.shape[1]
spatial_size = input_size[-2:]
for e, r, o in zip(expected_channels, expected_reduction, outputs):
assert e == o.shape[feat_axis]
assert o.shape[spatial_axis[0]] <= math.ceil(spatial_size[0] / r) + 1
assert o.shape[spatial_axis[1]] <= math.ceil(spatial_size[1] / r) + 1
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
if not _IS_MAC:
# MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer_kwargs = dict(
leaf_modules=get_notrace_modules(),
autowrap_functions=get_notrace_functions(),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer_kwargs = dict(
leaf_modules=list(_leaf_modules),
autowrap_functions=list(_autowrap_functions),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model
fx_model = create_feature_extractor(
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False)
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
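For reference, the layout-agnostic shape checks above boil down to the new format helpers. A minimal standalone sketch (tensor shapes are illustrative, not taken from the test suite):

import torch
from timm.layers import get_spatial_dim, get_channel_dim

# e.g. an NHWC feature map from one of the refactored channels-last models: (batch, H, W, C)
feats = torch.randn(2, 7, 7, 768)
output_fmt = 'NHWC'

spatial_axis = get_spatial_dim(output_fmt)   # (1, 2) for NHWC, (2, 3) for NCHW
feat_axis = get_channel_dim(output_fmt)      # 3 for NHWC, 1 for NCHW

assert tuple(feats.shape[d] for d in spatial_axis) == (7, 7)
assert feats.shape[feat_axis] == 768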

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
root,
reader=None,
split='train',
class_map=None,
is_training=False,
batch_size=None,
seed=42,
@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
reader,
root=root,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
seed=seed,


@ -157,6 +157,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
download=download,
batch_size=batch_size,
@ -169,6 +170,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,
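With class_map now threaded through to the TFDS/WDS readers, a custom label mapping can be passed straight to create_dataset. A hedged usage sketch; the dataset name, root path, and map file below are placeholders, and the .txt map is assumed to follow load_class_map's one-class-name-per-line format:

from timm.data import create_dataset

ds = create_dataset(
    'tfds/imagenet2012',            # any tfds/ or wds/ dataset name
    root='/data/tfds',              # placeholder root
    split='validation',
    class_map='my_synsets.txt',     # remaps raw reader targets to these indices
    batch_size=8,
)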


@ -7,12 +7,14 @@ from typing import Dict, List, Optional, Union
from .dataset_info import DatasetInfo
# NOTE no ambiguity wrt mapping from # classes to ImageNet subset so far, but likely to change
_NUM_CLASSES_TO_SUBSET = {
1000: 'imagenet-1k',
11821: 'imagenet-12k',
21841: 'imagenet-22k',
21843: 'imagenet-21k-goog',
11221: 'imagenet-21k-miil',
11221: 'imagenet-21k-miil', # miil subset of fall11
11821: 'imagenet-12k', # timm specific 12k subset of fall11
21841: 'imagenet-22k', # as in fall11.tar
21842: 'imagenet-22k-ms', # a Microsoft (for FocalNet) remapping of 22k that moves the ImageNet-1k classes to the first 1000 indices
21843: 'imagenet-21k-goog', # Google's ImageNet full has two classes not in fall11
}
_SUBSETS = {
@ -22,6 +24,7 @@ _SUBSETS = {
'imagenet21k': 'imagenet21k_goog_synsets.txt',
'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
'imagenet22kms': 'imagenet22k_ms_synsets.txt',
}
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
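The expanded table means a checkpoint's training label space can be inferred from its classifier size alone. A small sketch that mirrors _NUM_CLASSES_TO_SUBSET; the helper function is hypothetical, for illustration only:

# mirrors the mapping above: classifier size -> ImageNet subset name
NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11221: 'imagenet-21k-miil',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21842: 'imagenet-22k-ms',
    21843: 'imagenet-21k-goog',
}

def guess_imagenet_subset(num_classes: int) -> str:
    # hypothetical helper, not part of timm
    return NUM_CLASSES_TO_SUBSET.get(num_classes, 'unknown')

assert guess_imagenet_subset(21842) == 'imagenet-22k-ms'   # e.g. FocalNet 22k checkpoints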


@ -34,6 +34,7 @@ except ImportError as e:
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
@ -94,6 +95,7 @@ class ReaderTfds(Reader):
root,
name,
split='train',
class_map=None,
is_training=False,
batch_size=None,
download=False,
@ -151,7 +153,12 @@ class ReaderTfds(Reader):
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
if download:
self.builder.download_and_prepare()
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
@ -299,6 +306,8 @@ class ReaderTfds(Reader):
target_data = sample[self.target_name]
if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
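The remapping itself is a plain dict lookup on the yielded target, with images left untouched. A conceptual sketch of the same logic outside the reader, using made-up labels:

# stand-in for load_class_map(class_map): raw reader target -> remapped index
class_to_idx = {'n01440764': 0, 'n01443537': 1}

def remap_targets(samples, class_to_idx):
    # mirrors ReaderTfds/ReaderWds: only the target is remapped
    for img, target in samples:
        yield img, class_to_idx[target]

assert list(remap_targets([('img0', 'n01443537')], class_to_idx)) == [('img0', 1)]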


@ -29,6 +29,7 @@ except ImportError:
wds = None
expand_urls = None
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
info_yaml = os.path.join(root, basename + '.yaml')
err_str = ''
try:
with wds.gopen.gopen(info_json) as f:
with wds.gopen(info_json) as f:
info_dict = json.load(f)
return info_dict
except Exception as e:
err_str = str(e)
try:
with wds.gopen.gopen(info_yaml) as f:
with wds.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception:
@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
if 'splits' not in info or split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)
@ -290,6 +291,7 @@ class ReaderWds(Reader):
batch_size=None,
repeats=0,
seed=42,
class_map=None,
input_name='jpg',
input_image='RGB',
target_name='cls',
@ -320,6 +322,12 @@ class ReaderWds(Reader):
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = {}
# Distributed world state
self.dist_rank = 0
@ -431,7 +439,10 @@ class ReaderWds(Reader):
i = 0
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds:
yield sample[self.image_key], sample[self.target_key]
target = sample[self.target_key]
if self.remap_class:
target = self.class_to_idx[target]
yield sample[self.image_key], target
i += 1
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug


@ -20,6 +20,7 @@ from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple


@ -9,31 +9,37 @@ Both a functional and a nn.Module version of the pooling is provided.
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .format import get_spatial_dim, get_channel_dim
_int_tuple_2_t = Union[int, Tuple[int, int]]
def adaptive_pool_feat_mult(pool_type='avg'):
if pool_type == 'catavgmax':
if pool_type.endswith('catavgmax'):
return 2
else:
return 1
def adaptive_avgmax_pool2d(x, output_size=1):
def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return 0.5 * (x_avg + x_max)
def adaptive_catavgmax_pool2d(x, output_size=1):
def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return torch.cat((x_avg, x_max), 1)
def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
"""Selectable global pooling function with dynamic input kernel size
"""
if pool_type == 'avg':
@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
return x
class FastAdaptiveAvgPool2d(nn.Module):
def __init__(self, flatten=False):
super(FastAdaptiveAvgPool2d, self).__init__()
class FastAdaptiveAvgPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)
def forward(self, x):
return x.mean((2, 3), keepdim=not self.flatten)
return x.mean(self.dim, keepdim=not self.flatten)
class FastAdaptiveMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)
def forward(self, x):
return x.amax(self.dim, keepdim=not self.flatten)
class FastAdaptiveAvgMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim = get_spatial_dim(input_fmt)
def forward(self, x):
x_avg = x.mean(self.dim, keepdim=not self.flatten)
x_max = x.amax(self.dim, keepdim=not self.flatten)
return 0.5 * x_avg + 0.5 * x_max
class FastAdaptiveCatAvgMaxPool(nn.Module):
def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
super(FastAdaptiveCatAvgMaxPool, self).__init__()
self.flatten = flatten
self.dim_reduce = get_spatial_dim(input_fmt)
if flatten:
self.dim_cat = 1
else:
self.dim_cat = get_channel_dim(input_fmt)
def forward(self, x):
x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten)
x_max = x.amax(self.dim_reduce, keepdim=not self.flatten)
return torch.cat((x_avg, x_max), self.dim_cat)
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size
@ -68,7 +113,7 @@ class AdaptiveAvgMaxPool2d(nn.Module):
class AdaptiveCatAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
def __init__(self, output_size: _int_tuple_2_t = 1):
super(AdaptiveCatAvgMaxPool2d, self).__init__()
self.output_size = output_size
@ -79,26 +124,41 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size
"""
def __init__(self, output_size=1, pool_type='fast', flatten=False):
def __init__(
self,
output_size: _int_tuple_2_t = 1,
pool_type: str = 'fast',
flatten: bool = False,
input_fmt: str = 'NCHW',
):
super(SelectAdaptivePool2d, self).__init__()
assert input_fmt in ('NCHW', 'NHWC')
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
if pool_type == '':
if not pool_type:
self.pool = nn.Identity() # pass through
elif pool_type == 'fast':
assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(flatten)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
elif pool_type.startswith('fast') or input_fmt != 'NCHW':
assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
if pool_type.endswith('avgmax'):
self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('catavgmax'):
self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
elif pool_type.endswith('max'):
self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
else:
self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
self.flatten = nn.Identity()
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
assert input_fmt == 'NCHW'
if pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
self.pool = nn.AdaptiveAvgPool2d(output_size)
self.flatten = nn.Flatten(1) if flatten else nn.Identity()
def is_identity(self):
return not self.pool_type
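A short sketch of the format-aware pooling: with NHWC input only the fast path is allowed (output_size must stay 1) and the reduction happens over dims (1, 2). Shapes below are illustrative:

import torch
from timm.layers import SelectAdaptivePool2d

x_nhwc = torch.randn(2, 7, 7, 768)   # channels-last feature map

pool = SelectAdaptivePool2d(pool_type='avg', flatten=True, input_fmt='NHWC')
print(pool(x_nhwc).shape)            # torch.Size([2, 768])

cat_pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True, input_fmt='NHWC')
print(cat_pool(x_nhwc).shape)        # torch.Size([2, 1536]), avg and max concatenated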


@ -15,13 +15,23 @@ from .create_act import get_act_layer
from .create_norm import get_norm_layer
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
def _create_pool(
num_features: int,
num_classes: int,
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: Optional[str] = None,
):
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
global_pool = SelectAdaptivePool2d(
pool_type=pool_type,
flatten=flatten_in_pool,
input_fmt=input_fmt,
)
num_pooled_features = num_features * global_pool.feat_mult()
return global_pool, num_pooled_features
@ -36,9 +46,25 @@ def _create_fc(num_features, num_classes, use_conv=False):
return fc
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
def create_classifier(
num_features: int,
num_classes: int,
pool_type: str = 'avg',
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
global_pool, num_pooled_features = _create_pool(
num_features,
num_classes,
pool_type,
use_conv=use_conv,
input_fmt=input_fmt,
)
fc = _create_fc(
num_pooled_features,
num_classes,
use_conv=use_conv,
)
return global_pool, fc
@ -52,6 +78,7 @@ class ClassifierHead(nn.Module):
pool_type: str = 'avg',
drop_rate: float = 0.,
use_conv: bool = False,
input_fmt: str = 'NCHW',
):
"""
Args:
@ -64,28 +91,43 @@ class ClassifierHead(nn.Module):
self.drop_rate = drop_rate
self.in_features = in_features
self.use_conv = use_conv
self.input_fmt = input_fmt
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.global_pool, self.fc = create_classifier(
in_features,
num_classes,
pool_type,
use_conv=use_conv,
input_fmt=input_fmt,
)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
if global_pool != self.global_pool.pool_type:
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
def reset(self, num_classes, pool_type=None):
if pool_type is not None and pool_type != self.global_pool.pool_type:
self.global_pool, self.fc = create_classifier(
self.in_features,
num_classes,
pool_type=pool_type,
use_conv=self.use_conv,
input_fmt=self.input_fmt,
)
self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
else:
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(
num_pooled_features,
num_classes,
use_conv=self.use_conv,
)
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
x = self.fc(x)
return self.flatten(x)
class NormMlpClassifierHead(nn.Module):
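With input_fmt plumbed through, a pooled classifier can be attached directly to channels-last features. A minimal sketch (feature width and class count are arbitrary):

import torch
from timm.layers.classifier import create_classifier

global_pool, fc = create_classifier(
    num_features=768,
    num_classes=10,
    pool_type='avg',
    input_fmt='NHWC',
)

x = torch.randn(2, 7, 7, 768)    # NHWC features
logits = fc(global_pool(x))      # pool over (H, W), flatten, then Linear
print(logits.shape)              # torch.Size([2, 10])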


@ -0,0 +1,58 @@
from enum import Enum
from typing import Union
import torch
class Format(str, Enum):
NCHW = 'NCHW'
NHWC = 'NHWC'
NCL = 'NCL'
NLC = 'NLC'
FormatT = Union[str, Format]
def get_spatial_dim(fmt: FormatT):
fmt = Format(fmt)
if fmt is Format.NLC:
dim = (1,)
elif fmt is Format.NCL:
dim = (2,)
elif fmt is Format.NHWC:
dim = (1, 2)
else:
dim = (2, 3)
return dim
def get_channel_dim(fmt: FormatT):
fmt = Format(fmt)
if fmt is Format.NHWC:
dim = 3
elif fmt is Format.NLC:
dim = 2
else:
dim = 1
return dim
def nchw_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NHWC:
x = x.permute(0, 2, 3, 1)
elif fmt == Format.NLC:
x = x.flatten(2).transpose(1, 2)
elif fmt == Format.NCL:
x = x.flatten(2)
return x
def nhwc_to(x: torch.Tensor, fmt: Format):
if fmt == Format.NCHW:
x = x.permute(0, 3, 1, 2)
elif fmt == Format.NLC:
x = x.flatten(1, 2)
elif fmt == Format.NCL:
x = x.flatten(1, 2).transpose(1, 2)
return x
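The new format helpers are small but used throughout the refactor. Converting between layouts and querying axes looks like this:

import torch
from timm.layers import Format, get_channel_dim, get_spatial_dim, nchw_to

x = torch.randn(2, 64, 14, 14)        # NCHW

x_nhwc = nchw_to(x, Format.NHWC)      # permute -> (2, 14, 14, 64)
x_nlc = nchw_to(x, Format.NLC)        # flatten + transpose -> (2, 196, 64)

assert get_channel_dim('NHWC') == 3 and get_spatial_dim('NHWC') == (1, 2)
assert x_nhwc.shape == (2, 14, 14, 64)
assert x_nlc.shape == (2, 196, 64)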


@ -9,12 +9,13 @@ Based on code in:
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List
from typing import List, Optional, Callable
import torch
from torch import nn as nn
import torch.nn.functional as F
from .format import Format, nchw_to
from .helpers import to_2tuple
from .trace_utils import _assert
@ -24,15 +25,18 @@ _logger = logging.getLogger(__name__)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
output_fmt: Format
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten: bool = True,
output_fmt: Optional[str] = None,
bias: bool = True,
):
super().__init__()
img_size = to_2tuple(img_size)
@ -41,7 +45,13 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if output_fmt is not None:
self.flatten = False
self.output_fmt = Format(output_fmt)
else:
# flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten
self.output_fmt = Format.NCHW
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@ -52,7 +62,9 @@ class PatchEmbed(nn.Module):
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
elif self.output_fmt != Format.NCHW:
x = nchw_to(x, self.output_fmt)
x = self.norm(x)
return x
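Setting output_fmt on PatchEmbed makes the stem emit whatever layout the downstream stages expect, e.g. NHWC for the refactored Swin, while the default flatten-to-NLC behaviour is unchanged. Sketch:

import torch
from timm.layers import PatchEmbed

x = torch.randn(1, 3, 224, 224)

embed_nhwc = PatchEmbed(img_size=224, patch_size=16, embed_dim=96, output_fmt='NHWC')
print(embed_nhwc(x).shape)   # torch.Size([1, 14, 14, 96]), grid kept 2D, channels last

embed_nlc = PatchEmbed(img_size=224, patch_size=16, embed_dim=96)
print(embed_nlc(x).shape)    # torch.Size([1, 196, 96]), legacy flatten + transpose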


@ -15,26 +15,48 @@ from .weight_init import trunc_normal_
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
k_size: Optional[Tuple[int, int]] = None,
class_token: bool = False,
) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
# FIXME different q vs k sizes is a WIP, need to better offset the two grids?
q_coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
k_coords = torch.stack(
torch.meshgrid([
torch.arange(k_size[0]),
torch.arange(k_size[1])
])
).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
@ -59,7 +81,7 @@ class RelPosBias(nn.Module):
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0).view(-1),
persistent=False,
)
@ -69,7 +91,7 @@ class RelPosBias(nn.Module):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
@ -148,7 +170,7 @@ class RelPosMlp(nn.Module):
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
gen_relative_position_index(window_size).view(-1),
persistent=False)
# get relative_coords_table
@ -160,8 +182,7 @@ class RelPosMlp(nn.Module):
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)


@ -17,6 +17,7 @@ from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *
from .gluon_resnet import *
@ -71,13 +72,14 @@ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrai
from ._factory import create_model, parse_model_name, safe_model_name
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, register_notrace_function
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
register_notrace_module, is_notrace_module, get_notrace_modules, \
register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, \
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
from ._prune import adapt_model_from_string
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value


@ -392,6 +392,9 @@ def build_model_with_cfg(
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
output_fmt = getattr(model, 'output_fmt', None)
if output_fmt is not None:
feature_cfg.setdefault('output_fmt', output_fmt)
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
@ -403,7 +406,7 @@ def build_model_with_cfg(
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back pretrained cfg
model.default_cfg = model.pretrained_cfg # alias for rename backwards compat (default_cfg -> pretrained_cfg)
return model
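Because the builder now forwards a model's output_fmt into the feature wrapper, features_only works on the refactored channels-last archs and the wrapper reports the layout it returns. A hedged sketch, assuming the refactored Swin registers feature_info as exercised by the tests above:

import torch
import timm
from timm.layers import get_channel_dim

m = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, features_only=True)
m.eval()
feat_axis = get_channel_dim(getattr(m, 'output_fmt', 'NCHW'))

feats = m(torch.randn(1, 3, 224, 224))
for f, c, r in zip(feats, m.feature_info.channels(), m.feature_info.reduction()):
    assert f.shape[feat_axis] == c    # channels match feature_info regardless of layout
    print(tuple(f.shape), 'reduction', r)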


@ -3,10 +3,10 @@ from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit
from timm.layers import set_layer_config
from ._pretrained import PretrainedCfg, split_model_name_tag
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint
from ._pretrained import PretrainedCfg
from ._registry import is_model, model_entrypoint, split_model_name_tag
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']


@ -17,6 +17,8 @@ import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.layers import Format
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
@ -181,6 +183,7 @@ class FeatureDictNet(nn.ModuleDict):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
):
@ -195,6 +198,7 @@ class FeatureDictNet(nn.ModuleDict):
"""
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.output_fmt = Format(output_fmt)
self.concat = feature_concat
self.grad_checkpointing = False
self.return_layers = {}
@ -253,6 +257,7 @@ class FeatureListNet(FeatureDictNet):
self,
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
output_fmt: str = 'NCHW',
feature_concat: bool = False,
flatten_sequential: bool = False,
):
@ -264,9 +269,10 @@ class FeatureListNet(FeatureDictNet):
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
"""
super(FeatureListNet, self).__init__(
super().__init__(
model,
out_indices=out_indices,
output_fmt=output_fmt,
feature_concat=feature_concat,
flatten_sequential=flatten_sequential,
)
@ -293,7 +299,8 @@ class FeatureHookNet(nn.ModuleDict):
model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_as_dict: bool = False,
return_dict: bool = False,
output_fmt: str = 'NCHW',
no_rewrite: bool = False,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',
@ -304,16 +311,17 @@ class FeatureHookNet(nn.ModuleDict):
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
out_as_dict: Output features as a dict.
return_dict: Output features as a dict.
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
flatten_sequential arg must also be False if this is set True.
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
default_hook_type: The default hook type to use if not specified in model.feature_info.
"""
super(FeatureHookNet, self).__init__()
super().__init__()
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
self.return_dict = return_dict
self.output_fmt = Format(output_fmt)
self.grad_checkpointing = False
layers = OrderedDict()
@ -356,4 +364,4 @@ class FeatureHookNet(nn.ModuleDict):
else:
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())
return out if self.return_dict else list(out.values())


@ -19,6 +19,11 @@ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSa
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
@ -35,10 +40,6 @@ except ImportError:
pass
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
'FeatureGraphNet', 'GraphExtractNet']
def register_notrace_module(module: Type[nn.Module]):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
@ -47,6 +48,14 @@ def register_notrace_module(module: Type[nn.Module]):
return module
def is_notrace_module(module: Type[nn.Module]):
return module in _leaf_modules
def get_notrace_modules():
return list(_leaf_modules)
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
@ -59,6 +68,14 @@ def register_notrace_function(func: Callable):
return func
def is_notrace_function(func: Callable):
return func in _autowrap_functions
def get_notrace_functions():
return list(_autowrap_functions)
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
return _create_feature_extractor(


@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import logging
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union
import torch
try:
@ -13,30 +14,32 @@ try:
except ImportError:
_has_safetensors = False
import timm.models._builder
_logger = logging.getLogger(__name__)
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
def clean_state_dict(state_dict):
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
cleaned_state_dict = OrderedDict()
cleaned_state_dict = {}
for k, v in state_dict.items():
name = k[7:] if k.startswith('module.') else k
cleaned_state_dict[name] = v
return cleaned_state_dict
def load_state_dict(checkpoint_path, use_ema=True):
def load_state_dict(
checkpoint_path: str,
use_ema: bool = True,
device: Union[str, torch.device] = 'cpu',
) -> Dict[str, Any]:
if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict_key = ''
if isinstance(checkpoint, dict):
@ -56,22 +59,37 @@ def load_state_dict(checkpoint_path, use_ema=True):
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
def load_checkpoint(
model: torch.nn.Module,
checkpoint_path: str,
use_ema: bool = True,
device: Union[str, torch.device] = 'cpu',
strict: bool = True,
remap: bool = False,
filter_fn: Optional[Callable] = None,
):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
timm.models._model_builder.load_pretrained(checkpoint_path)
model.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
if remap:
state_dict = remap_checkpoint(model, state_dict)
state_dict = remap_state_dict(state_dict, model)
elif filter_fn:
state_dict = filter_fn(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def remap_checkpoint(model, state_dict, allow_reshape=True):
def remap_state_dict(
state_dict: Dict[str, Any],
model: torch.nn.Module,
allow_reshape: bool = True
):
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
This assumes models (and originating state dict) were created with params registered in same order.
"""
@ -87,7 +105,13 @@ def remap_checkpoint(model, state_dict, allow_reshape=True):
return out_dict
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
def resume_checkpoint(
model: torch.nn.Module,
checkpoint_path: str,
optimizer: torch.optim.Optimizer = None,
loss_scaler: Any = None,
log_info: bool = True,
):
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
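load_checkpoint now takes a target device and an optional filter_fn that is applied to the state dict before loading (called as filter_fn(state_dict, model), as above). A sketch with a hypothetical filter that drops the classifier; the checkpoint path is a placeholder:

import timm
from timm.models import load_checkpoint

def drop_classifier(state_dict, model):
    # hypothetical filter: drop the 'fc.' head (ResNet's classifier) so a new head size can be trained
    return {k: v for k, v in state_dict.items() if not k.startswith('fc.')}

model = timm.create_model('resnet50', pretrained=False, num_classes=10)
load_checkpoint(model, 'checkpoint.pth.tar', filter_fn=drop_classifier, strict=False)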


@ -3,7 +3,7 @@ import math
import re
from collections import defaultdict
from itertools import chain
from typing import Callable, Union, Dict
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
import torch
from torch import nn as nn
@ -13,7 +13,7 @@ __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_wi
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
def model_parameters(model, exclude_head=False):
def model_parameters(model: nn.Module, exclude_head: bool = False):
if exclude_head:
# FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
@ -21,7 +21,12 @@ def model_parameters(model, exclude_head=False):
return model.parameters()
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
def named_apply(
fn: Callable,
module: nn.Module, name='',
depth_first: bool = True,
include_root: bool = False,
) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
return module
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
def named_modules(
module: nn.Module,
name: str = '',
depth_first: bool = True,
include_root: bool = False,
):
if not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
yield name, module
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
def named_modules_with_params(
module: nn.Module,
name: str = '',
depth_first: bool = True,
include_root: bool = False,
):
if module._parameters and not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
@ -58,9 +73,9 @@ MATCH_PREV_GROUP = (99999,)
def group_with_matcher(
named_objects,
named_objects: Iterator[Tuple[str, Any]],
group_matcher: Union[Dict, Callable],
output_values: bool = False,
return_values: bool = False,
reverse: bool = False
):
if isinstance(group_matcher, dict):
@ -96,7 +111,7 @@ def group_with_matcher(
# map layers into groups via ordinals (ints or tuples of ints) from matcher
grouping = defaultdict(list)
for k, v in named_objects:
grouping[_get_grouping(k)].append(v if output_values else k)
grouping[_get_grouping(k)].append(v if return_values else k)
# remap to integers
layer_id_to_param = defaultdict(list)
@ -107,7 +122,7 @@ def group_with_matcher(
layer_id_to_param[lid].extend(grouping[k])
if reverse:
assert not output_values, "reverse mapping only sensible for name output"
assert not return_values, "reverse mapping only sensible for name output"
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
@ -121,24 +136,29 @@ def group_with_matcher(
def group_parameters(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
return_values: bool = False,
reverse: bool = False,
):
return group_with_matcher(
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)
def group_modules(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
return_values: bool = False,
reverse: bool = False,
):
return group_with_matcher(
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
def flatten_modules(
named_modules: Iterator[Tuple[str, nn.Module]],
depth: int = 1,
prefix: Union[str, Tuple[str, ...]] = '',
module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
):
prefix_is_tuple = isinstance(prefix, tuple)
if isinstance(module_types, str):
if module_types == 'container':
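Only the keyword changed here (output_values -> return_values); grouping still keys parameters off a group_matcher spec, e.g. the one a model provides for layer-wise LR decay. Sketch:

import timm
from timm.models import group_parameters

model = timm.create_model('resnet18', pretrained=False)
matcher = model.group_matcher(coarse=True)            # model-provided grouping spec

groups = group_parameters(model, matcher, return_values=False)
for group_id, param_names in sorted(groups.items()):
    print(group_id, len(param_names))                 # group ordinal -> parameter names in that group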


@ -4,7 +4,7 @@ from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
@dataclass
@ -91,41 +91,3 @@ class DefaultCfg:
def default_with_tag(self):
tag = self.tags[0]
return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
out = defaultdict(DefaultCfg)
default_set = set() # no tag and tags ending with * are prioritized as default
for k, v in cfgs.items():
if isinstance(v, dict):
v = PretrainedCfg(**v)
has_weights = v.has_weights
model, tag = split_model_name_tag(k)
is_default_set = model in default_set
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
tag = tag.strip('*')
default_cfg = out[model]
if priority:
default_cfg.tags.appendleft(tag)
default_set.add(model)
elif has_weights and not default_cfg.is_pretrained:
default_cfg.tags.appendleft(tag)
else:
default_cfg.tags.append(tag)
if has_weights:
default_cfg.is_pretrained = True
default_cfg.cfgs[tag] = v
return out

View File

@ -5,16 +5,19 @@ Hacked together by / Copyright 2020 Ross Wightman
import fnmatch
import re
import sys
import warnings
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
from ._pretrained import PretrainedCfg, DefaultCfg
__all__ = [
'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
'get_pretrained_cfg_value', 'is_model_pretrained'
]
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
@ -23,12 +26,52 @@ _model_has_pretrained: Set[str] = set() # set of model names that have pretrain
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
_deprecated_models: Dict[str, Optional[str]] = {}
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
def get_arch_name(model_name: str) -> str:
return split_model_name_tag(model_name)[0]
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
out = defaultdict(DefaultCfg)
default_set = set() # no tag and tags ending with * are prioritized as default
for k, v in cfgs.items():
if isinstance(v, dict):
v = PretrainedCfg(**v)
has_weights = v.has_weights
model, tag = split_model_name_tag(k)
is_default_set = model in default_set
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
tag = tag.strip('*')
default_cfg = out[model]
if priority:
default_cfg.tags.appendleft(tag)
default_set.add(model)
elif has_weights and not default_cfg.is_pretrained:
default_cfg.tags.appendleft(tag)
else:
default_cfg.tags.append(tag)
if has_weights:
default_cfg.is_pretrained = True
default_cfg.cfgs[tag] = v
return out
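# A hedged sketch of how generate_default_cfgs groups 'arch.tag' keys (the model
# name below is purely illustrative, not a real entrypoint):
#
#   cfgs = {
#       'mymodel.tag_a': PretrainedCfg(hf_hub_id='timm/'),  # has weights -> promoted to default tag
#       'mymodel.tag_b': PretrainedCfg(),                    # no weights -> appended after it
#   }
#   out = generate_default_cfgs(cfgs)
#   # out['mymodel'] is a DefaultCfg with tags deque ('tag_a', 'tag_b'),
#   # is_pretrained=True, and cfgs mapping each tag to its PretrainedCfg.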
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module
mod = sys.modules[fn.__module__]
@ -87,6 +130,37 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
return fn
def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
def _fn(pretrained=False, **kwargs):
assert current_fn is not None, f'Model {deprecated_name} has been removed with no replacement.'
warnings.warn(f'Mapping deprecated model {deprecated_name} to current {current_fn.__name__}', stacklevel=2)
pretrained_cfg = kwargs.pop('pretrained_cfg', None)
return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
return _fn
def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
mod = sys.modules[module_name]
module_name_split = module_name.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
for deprecated, current in deprecation_map.items():
if hasattr(mod, '__all__'):
mod.__all__.append(deprecated)
current_fn = None
current_tag = ''
if current:
current_name, current_tag = split_model_name_tag(current)
current_fn = getattr(mod, current_name)
deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
setattr(mod, deprecated, deprecated_entrypoint_fn)
_model_entrypoints[deprecated] = deprecated_entrypoint_fn
_model_to_module[deprecated] = module_name
_module_to_models[module_name].add(deprecated)
_deprecated_models[deprecated] = current
_module_to_deprecated_models[module_name][deprecated] = current
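# A hedged sketch of the deprecation shim in use (entrypoint names below are illustrative):
#
#   register_model_deprecations(__name__, {'old_arch': 'new_arch.some_tag'})
#   model = timm.create_model('old_arch')
#   # warns via warnings.warn(...) and dispatches to new_arch() with
#   # pretrained_cfg='some_tag', so the old model name keeps working.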
def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
@ -122,16 +196,14 @@ def list_models(
# FIXME should this be default behaviour? or default to include_tags=True?
include_tags = pretrained
if module:
all_models: Iterable[str] = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings
if include_tags:
# expand model names to include names w/ pretrained tags
models_with_tags = []
models_with_tags: Set[str] = set()
for m in all_models:
models_with_tags.extend(_model_with_tags[m])
models_with_tags.update(_model_with_tags[m])
all_models = models_with_tags
if filter:
@ -142,7 +214,7 @@ def list_models(
if len(include_models):
models = models.union(include_models)
else:
models = set(all_models)
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
@ -173,6 +245,11 @@ def list_pretrained(
)
def get_deprecated_models(module: str = '') -> Dict[str, str]:
all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
return deepcopy(all_deprecated)
def is_model(model_name: str) -> bool:
""" Check if a model name exists
"""

View File

@ -63,8 +63,7 @@ from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
__all__ = ['Beit']

View File

@ -50,8 +50,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
@ -406,7 +405,7 @@ class ConvNeXt(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None):
self.head.reset(num_classes, global_pool=global_pool)
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -415,7 +414,7 @@ class ConvNeXt(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -519,6 +518,13 @@ def _cfgv2(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
@ -558,12 +564,6 @@ default_cfgs = generate_default_cfgs({
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
@ -582,25 +582,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.untrained': _cfg(),
'convnext_xxlarge.untrained': _cfg(),
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
hf_hub_id='timm/',
@ -622,6 +603,23 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
hf_hub_id='timm/',
@ -1038,3 +1036,22 @@ def convnextv2_huge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
return model
register_model_deprecations(__name__, {
'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
'convnext_small_in22k': 'convnext_small.fb_in22k',
'convnext_base_in22k': 'convnext_base.fb_in22k',
'convnext_large_in22k': 'convnext_large.fb_in22k',
'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
})

View File

@ -11,8 +11,6 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
import itertools
from collections import OrderedDict
from functools import partial
from typing import Tuple
@ -22,13 +20,12 @@ import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['DaViT']

View File

@ -359,10 +359,9 @@ class DLA(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
x = self.fc(x)
return self.flatten(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -298,10 +298,9 @@ class DPN(nn.Module):
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
if pre_logits:
return x.flatten(1)
else:
x = self.classifier(x)
return self.flatten(x)
x = self.classifier(x)
return self.flatten(x)
def forward(self, x):
x = self.forward_features(x)

View File

@ -21,8 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this

View File

@ -26,8 +26,7 @@ from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_nor
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
EfficientFormer_width = {

View File

@ -51,8 +51,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['EfficientNet', 'EfficientNetFeatures']
@ -1064,42 +1063,46 @@ default_cfgs = generate_default_cfgs({
'efficientnetv2_xl.untrained': _cfg(
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnet_b0.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
hf_hub_id='timm/',
input_size=(3, 224, 224)),
'tf_efficientnet_b1.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
hf_hub_id='timm/',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
hf_hub_id='timm/',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
hf_hub_id='timm/',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
hf_hub_id='timm/',
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
hf_hub_id='timm/',
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b6.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
hf_hub_id='timm/',
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'tf_efficientnet_b7.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
hf_hub_id='timm/',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_b8.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
hf_hub_id='timm/',
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
hf_hub_id='timm/',
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
'tf_efficientnet_b0.ap_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
@ -1146,46 +1149,42 @@ default_cfgs = generate_default_cfgs({
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
'tf_efficientnet_b0.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
hf_hub_id='timm/',
input_size=(3, 224, 224)),
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
'tf_efficientnet_b1.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
hf_hub_id='timm/',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
'tf_efficientnet_b2.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
hf_hub_id='timm/',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
'tf_efficientnet_b3.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
hf_hub_id='timm/',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
'tf_efficientnet_b4.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
hf_hub_id='timm/',
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
'tf_efficientnet_b5.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
hf_hub_id='timm/',
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
'tf_efficientnet_b6.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
hf_hub_id='timm/',
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
'tf_efficientnet_b7.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
hf_hub_id='timm/',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
'tf_efficientnet_b8.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
hf_hub_id='timm/',
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
hf_hub_id='timm/',
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_es.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
@ -1248,22 +1247,6 @@ default_cfgs = generate_default_cfgs({
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
'tf_efficientnetv2_s.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_l.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_s.in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
hf_hub_id='timm/',
@ -1285,6 +1268,22 @@ default_cfgs = generate_default_cfgs({
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_s.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_l.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_s.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
hf_hub_id='timm/',
@ -2289,3 +2288,34 @@ def tinynet_d(pretrained=False, **kwargs):
def tinynet_e(pretrained=False, **kwargs):
model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
return model
register_model_deprecations(__name__, {
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k',
'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k',
'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k',
'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k',
'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k',
'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k',
'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k',
'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k',
'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k',
'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k',
'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k',
'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k',
'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k',
'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k',
'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k',
'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475',
'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k',
'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k',
'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k',
'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k',
'tf_efficientnetv2_xl_in21ft1k': 'tf_efficientnetv2_xl.in21k_ft_in1k',
'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
})

View File

@ -0,0 +1,643 @@
""" FocalNet
As described in `Focal Modulation Networks` - https://arxiv.org/abs/2203.11926
Significant modifications and refactoring from the original impl at https://github.com/microsoft/FocalNet
This impl is/has:
* fully convolutional, NCHW tensor layout throughout; this seemed to have minimal performance impact and is more flexible
* re-ordered downsample / layer so that striding always at beginning of layer (stage)
* no input size constraints or input resolution/H/W tracking through the model
* torchscript fixed and a number of quirks cleaned up
* feature extraction support via `features_only=True`
"""
# --------------------------------------------------------
# FocalNets -- Focal Modulation Networks
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------
from functools import partial
from typing import Callable, Optional, Tuple
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model
__all__ = ['FocalNet']
class FocalModulation(nn.Module):
def __init__(
self,
dim: int,
focal_window,
focal_level: int,
focal_factor: int = 2,
bias: bool = True,
use_post_norm: bool = False,
normalize_modulator: bool = False,
proj_drop: float = 0.,
norm_layer: Callable = LayerNorm2d,
):
super().__init__()
self.dim = dim
self.focal_window = focal_window
self.focal_level = focal_level
self.focal_factor = focal_factor
self.use_post_norm = use_post_norm
self.normalize_modulator = normalize_modulator
self.input_split = [dim, dim, self.focal_level + 1]
self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.act = nn.GELU()
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
self.proj_drop = nn.Dropout(proj_drop)
self.focal_layers = nn.ModuleList()
self.kernel_sizes = []
for k in range(self.focal_level):
kernel_size = self.focal_factor * k + self.focal_window
self.focal_layers.append(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False),
nn.GELU(),
))
self.kernel_sizes.append(kernel_size)
self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity()
def forward(self, x):
"""
Args:
x: input features with shape (B, C, H, W)
"""
C = x.shape[1]
# pre linear projection
x = self.f(x)
q, ctx, gates = torch.split(x, self.input_split, 1)
# context aggregation
ctx_all = 0
for l, focal_layer in enumerate(self.focal_layers):
ctx = focal_layer(ctx)
ctx_all = ctx_all + ctx * gates[:, l:l + 1]
ctx_global = self.act(ctx.mean((2, 3), keepdim=True))
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
# normalize context
if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1)
# focal modulation
x_out = q * self.h(ctx_all)
x_out = self.norm(x_out)
# post linear projection
x_out = self.proj(x_out)
x_out = self.proj_drop(x_out)
return x_out
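# Worked example of the channel bookkeeping in FocalModulation.forward, assuming
# dim=96 and focal_level=2: self.f outputs 2*96 + (2+1) = 195 channels, split into
# q (96), ctx (96) and gates (3); gate channels 0-1 weight the two focal levels and
# gate channel 2 weights the global context taken from ctx.mean((2, 3)).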
class LayerScale2d(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
gamma = self.gamma.view(1, -1, 1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
class FocalNetBlock(nn.Module):
""" Focal Modulation Network Block.
"""
def __init__(
self,
dim: int,
mlp_ratio: float = 4.,
focal_level: int = 1,
focal_window: int = 3,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
layerscale_value: float = 1e-4,
proj_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm2d,
):
"""
Args:
dim: Number of input channels.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
focal_level: Number of focal levels.
focal_window: Focal window size at first focal level.
use_post_norm: Whether to use layer norm after modulation.
use_post_norm_in_modulation: Whether to use layer norm in modulation.
layerscale_value: Initial layerscale value.
proj_drop: Dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.focal_window = focal_window
self.focal_level = focal_level
self.use_post_norm = use_post_norm
self.norm1 = norm_layer(dim) if not use_post_norm else nn.Identity()
self.modulation = FocalModulation(
dim,
focal_window=focal_window,
focal_level=self.focal_level,
use_post_norm=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.norm1_post = norm_layer(dim) if use_post_norm else nn.Identity()
self.ls1 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) if not use_post_norm else nn.Identity()
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
use_conv=True,
)
self.norm2_post = norm_layer(dim) if use_post_norm else nn.Identity()
self.ls2 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
# Focal Modulation
x = self.norm1(x)
x = self.modulation(x)
x = self.norm1_post(x)
x = shortcut + self.drop_path1(self.ls1(x))
# FFN
x = x + self.drop_path2(self.ls2(self.norm2_post(self.mlp(self.norm2(x)))))
return x
class FocalNetStage(nn.Module):
""" A basic Focal Transformer layer for one stage.
"""
def __init__(
self,
dim: int,
out_dim: int,
depth: int,
mlp_ratio: float = 4.,
downsample: bool = True,
focal_level: int = 1,
focal_window: int = 1,
use_overlap_down: bool = False,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
layerscale_value: float = 1e-4,
proj_drop: float = 0.,
drop_path: float = 0.,
norm_layer: Callable = LayerNorm2d,
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels.
depth: Number of blocks.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
downsample: Downsample layer at start of the layer.
focal_level: Number of focal levels
focal_window: Focal window size at first focal level
use_overlap_down: Use overlapped convolution in downsample layer.
use_post_norm: Whether to use layer norm after modulation.
use_post_norm_in_modulation: Whether to use layer norm in modulation.
layerscale_value: Initial layerscale value
proj_drop: Dropout rate for projections.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.depth = depth
self.grad_checkpointing = False
if downsample:
self.downsample = Downsample(
in_chs=dim,
out_chs=out_dim,
stride=2,
overlap=use_overlap_down,
norm_layer=norm_layer,
)
else:
self.downsample = nn.Identity()
# build blocks
self.blocks = nn.ModuleList([
FocalNetBlock(
dim=out_dim,
mlp_ratio=mlp_ratio,
focal_level=focal_level,
focal_window=focal_window,
use_post_norm=use_post_norm,
use_post_norm_in_modulation=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
layerscale_value=layerscale_value,
proj_drop=proj_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)])
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x):
x = self.downsample(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
return x
class Downsample(nn.Module):
def __init__(
self,
in_chs: int,
out_chs: int,
stride: int = 4,
overlap: bool = False,
norm_layer: Optional[Callable] = None,
):
"""
Args:
in_chs: Number of input image channels.
out_chs: Number of linear projection output channels.
stride: Downsample stride.
overlap: Use overlapping convolutions if True.
norm_layer: Normalization layer.
"""
super().__init__()
self.stride = stride
padding = 0
kernel_size = stride
if overlap:
assert stride in (2, 4)
if stride == 4:
kernel_size, padding = 7, 2
elif stride == 2:
kernel_size, padding = 3, 1
self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding)
self.norm = norm_layer(out_chs) if norm_layer is not None else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class FocalNet(nn.Module):
"""" Focal Modulation Networks (FocalNets)
"""
def __init__(
self,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
mlp_ratio: float = 4.,
focal_levels: Tuple[int, ...] = (2, 2, 2, 2),
focal_windows: Tuple[int, ...] = (3, 3, 3, 3),
use_overlap_down: bool = False,
use_post_norm: bool = False,
use_post_norm_in_modulation: bool = False,
normalize_modulator: bool = False,
head_hidden_size: Optional[int] = None,
head_init_scale: float = 1.0,
layerscale_value: Optional[float] = None,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Callable = partial(LayerNorm2d, eps=1e-5),
):
"""
Args:
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
embed_dim: Patch embedding dimension.
depths: Depth of each Focal Transformer layer.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
focal_levels: How many focal levels at all stages. Note that this excludes the finest-grain level.
focal_windows: The focal window size at all stages.
use_overlap_down: Whether to use convolutional embedding.
use_post_norm: Whether to use layernorm after modulation (it helps stabilize training of large models)
layerscale_value: Value for layer scale.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
"""
super().__init__()
self.num_layers = len(depths)
embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]
self.num_classes = num_classes
self.embed_dim = embed_dim
self.num_features = embed_dim[-1]
self.feature_info = []
self.stem = Downsample(
in_chs=in_chans,
out_chs=embed_dim[0],
overlap=use_overlap_down,
norm_layer=norm_layer,
)
in_dim = embed_dim[0]
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
layers = []
for i_layer in range(self.num_layers):
out_dim = embed_dim[i_layer]
layer = FocalNetStage(
dim=in_dim,
out_dim=out_dim,
depth=depths[i_layer],
mlp_ratio=mlp_ratio,
downsample=i_layer > 0,
focal_level=focal_levels[i_layer],
focal_window=focal_windows[i_layer],
use_overlap_down=use_overlap_down,
use_post_norm=use_post_norm,
use_post_norm_in_modulation=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
layerscale_value=layerscale_value,
proj_drop=proj_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
)
in_dim = out_dim
layers += [layer]
self.feature_info += [dict(num_chs=out_dim, reduction=4 * 2 ** i_layer, module=f'layers.{i_layer}')]
self.layers = nn.Sequential(*layers)
if head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=norm_layer,
)
else:
self.norm = norm_layer(self.num_features)
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate
)
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
def no_weight_decay(self):
return {''}
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for l in self.layers:
l.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.stem(x)
x = self.layers(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
if name and 'head.fc' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.proj', 'classifier': 'head.fc',
'license': 'mit', **kwargs
}
default_cfgs = generate_default_cfgs({
"focalnet_tiny_srf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_small_srf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_base_srf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_tiny_lrf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_small_lrf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_base_lrf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_large_fl3.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_large_fl4.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl3.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl4.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_huge_fl3.ms_in22k": _cfg(
hf_hub_id='timm/',
num_classes=21842),
"focalnet_huge_fl4.ms_in22k": _cfg(
hf_hub_id='timm/',
num_classes=0),
})
def checkpoint_filter_fn(state_dict, model: FocalNet):
state_dict = state_dict.get('model', state_dict)
if 'stem.proj.weight' in state_dict:
return state_dict
import re
out_dict = {}
dest_dict = model.state_dict()
for k, v in state_dict.items():
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
k = k.replace('patch_embed', 'stem')
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
if 'norm' in k and k not in dest_dict:
k = re.sub(r'norm([0-9])', r'norm\1_post', k)
k = k.replace('ln.', 'norm.')
k = k.replace('head', 'head.fc')
if k in dest_dict and dest_dict[k].numel() == v.numel() and dest_dict[k].shape != v.shape:
v = v.reshape(dest_dict[k].shape)
out_dict[k] = v
return out_dict
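# Illustrative key remappings performed by checkpoint_filter_fn when loading the
# original Microsoft checkpoints (examples only, not exhaustive):
#   'patch_embed.proj.weight'         -> 'stem.proj.weight'
#   'layers.1.downsample.proj.weight' -> 'layers.2.downsample.proj.weight'
#       (downsampling now sits at the start of the following stage)
#   '...gamma_1'                      -> '...ls1.gamma'
#   'head.weight'                     -> 'head.fc.weight'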
def _create_focalnet(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
FocalNet, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
@register_model
def focalnet_tiny_srf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
return _create_focalnet('focalnet_tiny_srf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_small_srf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
return _create_focalnet('focalnet_small_srf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_base_srf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
return _create_focalnet('focalnet_base_srf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_tiny_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_tiny_lrf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_small_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_small_lrf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_base_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_base_lrf', pretrained=pretrained, **model_kwargs)
# FocalNet large+ models
@register_model
def focalnet_large_fl3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_large_fl3', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_large_fl4(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_xlarge_fl3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_xlarge_fl3', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_xlarge_fl4(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_xlarge_fl4', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_huge_fl3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], focal_windows=[3] * 4,
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_huge_fl3', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_huge_fl4(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_huge_fl4', pretrained=pretrained, **model_kwargs)

View File

@ -29,12 +29,11 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, _assert
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import register_model
from .vision_transformer_relpos import RelPosBias # FIXME move to common location
__all__ = ['GlobalContextVit']
@ -222,7 +221,7 @@ class WindowAttentionGlobal(nn.Module):
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
attn = self.rel_pos(attn)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

View File

@ -24,7 +24,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Dict
@ -35,9 +34,7 @@ from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['Levit']

View File

@ -52,8 +52,7 @@ from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']

View File

@ -22,8 +22,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['MobileNetV3', 'MobileNetV3Features']
@ -145,13 +144,12 @@ class MobileNetV3(nn.Module):
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
x = self.flatten(x)
if pre_logits:
return x.flatten(1)
else:
x = self.flatten(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
return x
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
def forward(self, x):
x = self.forward_features(x)
@ -797,3 +795,9 @@ def lcnet_150(pretrained=False, **kwargs):
""" PP-LCNet 1.5"""
model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
return model
register_model_deprecations(__name__, {
'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
})

View File

@ -17,86 +17,24 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# --------------------------------------------------------
import logging
import math
from typing import Optional
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model
from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from .vision_transformer import get_init_weights_vit
__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'swin_base_patch4_window12_384': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'swin_base_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
),
'swin_large_patch4_window12_384': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'swin_large_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
),
'swin_small_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
),
'swin_tiny_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
),
'swin_base_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
'swin_base_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_large_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
'swin_large_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_s3_tiny_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
),
'swin_s3_small_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
),
'swin_s3_base_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
)
}
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
def window_partition(x, window_size: int):
@ -132,7 +70,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
return x
def get_relative_position_index(win_h, win_w):
def get_relative_position_index(win_h: int, win_w: int):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
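# For reference: with the default 7x7 window the index built here has shape
# (49, 49) with values in [0, 168], addressing a (2*7-1) * (2*7-1) = 169 entry
# relative position bias table per attention head.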
@ -145,21 +83,30 @@ def get_relative_position_index(win_h, win_w):
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
head_dim (int): Number of channels per head (dim // num_heads if not set)
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports shifted and non-shifted windows.
"""
def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim: int,
num_heads: int,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
qkv_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
):
"""
Args:
dim: Number of input channels.
num_heads: Number of attention heads.
head_dim: Number of channels per head (dim // num_heads if not set)
window_size: The height and width of the window.
qkv_bias: If True, add a learnable bias to query, key, value.
attn_drop: Dropout ratio of attention weight.
proj_drop: Dropout ratio of output.
"""
super().__init__()
self.dim = dim
self.window_size = to_2tuple(window_size) # Wh, Ww
@ -198,7 +145,7 @@ class WindowAttention(nn.Module):
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
@ -206,7 +153,7 @@ class WindowAttention(nn.Module):
if mask is not None:
num_win = mask.shape[0]
attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
@ -221,28 +168,41 @@ class WindowAttention(nn.Module):
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
window_size (int): Window size.
num_heads (int): Number of attention heads.
head_dim (int): Enforce the number of channels per head
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
""" Swin Transformer Block.
"""
def __init__(
self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim: int,
input_resolution: _int_or_tuple_2_t,
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
shift_size: int = 0,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
window_size: Window size.
num_heads: Number of attention heads.
head_dim: Enforce the number of channels per head
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@ -257,12 +217,23 @@ class SwinTransformerBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
head_dim=head_dim,
window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
@ -285,17 +256,15 @@ class SwinTransformerBlock(nn.Module):
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
B, H, W, C = x.shape
_assert(H == self.input_resolution[0], "input feature has wrong size")
_assert(W == self.input_resolution[1], "input feature has wrong size")
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
@ -319,176 +288,221 @@ class SwinTransformerBlock(nn.Module):
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.reshape(B, -1, C)
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.reshape(B, H, W, C)
return x
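A minimal usage sketch (not part of the diff) of the refactored block's NHWC contract; it assumes SwinTransformerBlock can be imported from the internal timm.models.swin_transformer module.
import torch
from timm.models.swin_transformer import SwinTransformerBlock  # internal path, assumed importable
blk = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7)
x = torch.randn(2, 56, 56, 96)  # B, H, W, C -- the block now takes NHWC maps, not flattened B, L, C tokens
print(blk(x).shape)             # expected: torch.Size([2, 56, 56, 96])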
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
""" Patch Merging Layer.
"""
def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
def __init__(
self,
dim: int,
out_dim: Optional[int] = None,
norm_layer: Callable = nn.LayerNorm,
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels (or 2 * dim if None)
norm_layer: Normalization layer.
"""
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim or 2 * dim
self.norm = norm_layer(4 * dim)
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.norm(x)
x = self.reduction(x)
return x
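A small standalone check (plain PyTorch, added for reference) showing that the reshape/permute/flatten above folds each 2x2 neighborhood into the channel dim with the same ordering as the old slice-and-concat:
import torch
B, H, W, C = 1, 4, 4, 8
x = torch.randn(B, H, W, C)
merged = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
old = torch.cat([x[:, 0::2, 0::2], x[:, 1::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 1::2]], dim=-1)
print(merged.shape)              # torch.Size([1, 2, 2, 32])
print(torch.equal(merged, old))  # True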
class BasicLayer(nn.Module):
class SwinTransformerStage(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
head_dim (int): Channels per head (dim // num_heads if not set)
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def __init__(
self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None,
window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
self,
dim: int,
out_dim: int,
input_resolution: Tuple[int, int],
depth: int,
downsample: bool = True,
num_heads: int = 4,
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.,
norm_layer: Callable = nn.LayerNorm,
):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels.
input_resolution: Input resolution.
depth: Number of blocks.
downsample: Use a patch merging downsample layer at the start of the stage.
num_heads: Number of attention heads.
head_dim: Channels per head (dim // num_heads if not set)
window_size: Local window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
self.depth = depth
self.grad_checkpointing = False
# patch merging layer
if downsample:
self.downsample = PatchMerging(
dim=dim,
out_dim=out_dim,
norm_layer=norm_layer,
)
else:
assert dim == out_dim
self.downsample = nn.Identity()
# build blocks
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim,
window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
dim=out_dim,
input_resolution=self.output_resolution,
num_heads=num_heads,
head_dim=head_dim,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
if self.downsample is not None:
x = self.downsample(x)
return x
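A hedged sketch of the stage's downsample-first behaviour (again assuming the internal class is importable): patch merging runs before the blocks, so spatial size halves and channels go from dim to out_dim.
import torch
from timm.models.swin_transformer import SwinTransformerStage  # internal path, assumed importable
stage = SwinTransformerStage(dim=96, out_dim=192, input_resolution=(56, 56), depth=2, downsample=True, num_heads=6)
x = torch.randn(1, 56, 56, 96)
print(stage(x).shape)  # expected: torch.Size([1, 28, 28, 192])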
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
""" Swin Transformer
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
head_dim (int, tuple(int)):
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None,
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs):
self,
img_size: _int_or_tuple_2_t = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Union[str, Callable] = nn.LayerNorm,
weight_init: str = '',
**kwargs,
):
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
embed_dim: Patch embedding dimension.
depths: Depth of each Swin Transformer layer.
num_heads: Number of attention heads in different layers.
head_dim: Dimension of self-attention heads.
window_size: Window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop_rate: Dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
"""
super().__init__()
assert global_pool in ('', 'avg')
self.num_classes = num_classes
self.global_pool = global_pool
self.output_fmt = 'NHWC'
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.feature_info = []
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if patch_norm else None)
num_patches = self.patch_embed.num_patches
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim[0],
norm_layer=norm_layer,
output_fmt='NHWC',
)
self.patch_grid = self.patch_embed.grid_size
# absolute position embedding
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None
self.pos_drop = nn.Dropout(p=drop_rate)
# build layers
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
embed_out_dim = embed_dim[1:] + [None]
head_dim = to_ntuple(self.num_layers)(head_dim)
window_size = to_ntuple(self.num_layers)(window_size)
mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
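# e.g. depths=(2, 2, 6, 2), drop_path_rate=0.1 -> 12 evenly spaced values split per stage, roughly
# [[0.000, 0.009], [0.018, 0.027], [0.036, ..., 0.082], [0.091, 0.100]] (values rounded)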
layers = []
in_dim = embed_dim[0]
scale = 1
for i in range(self.num_layers):
layers += [BasicLayer(
dim=embed_dim[i],
out_dim=embed_out_dim[i],
input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)),
out_dim = embed_dim[i]
layers += [SwinTransformerStage(
dim=in_dim,
out_dim=out_dim,
input_resolution=(
self.patch_grid[0] // scale,
self.patch_grid[1] // scale
),
depth=depths[i],
downsample=i > 0,
num_heads=num_heads[i],
head_dim=head_dim[i],
window_size=window_size[i],
@ -496,29 +510,35 @@ class SwinTransformer(nn.Module):
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
drop_path=dpr[i],
norm_layer=norm_layer,
downsample=PatchMerging if (i < self.num_layers - 1) else None
)]
in_dim = out_dim
if i > 0:
scale *= 2
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
input_fmt=self.output_fmt,
)
if weight_init != 'skip':
self.init_weights(weight_init)
@torch.jit.ignore
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
if self.absolute_pos_embed is not None:
trunc_normal_(self.absolute_pos_embed, std=.02)
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
nwd = {'absolute_pos_embed'}
nwd = set()
for n, _ in self.named_parameters():
if 'relative_position_bias_table' in n:
nwd.add(n)
@ -527,7 +547,7 @@ class SwinTransformer(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
stem=r'^patch_embed', # stem and embed
blocks=r'^layers\.(\d+)' if coarse else [
(r'^layers\.(\d+).downsample', (0,)),
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
@ -542,28 +562,20 @@ class SwinTransformer(nn.Module):
@torch.jit.ignore
def get_classifier(self):
return self.head
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.reset(num_classes, pool_type=global_pool)
def forward_features(self, x):
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
x = self.layers(x)
x = self.norm(x) # B L C
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
return x if pre_logits else self.head(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -571,58 +583,118 @@ class SwinTransformer(nn.Module):
return x
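To illustrate the unpooled NHWC feature path end to end, a short sketch (not part of the diff; shapes assume the default swin_tiny configuration):
import timm
import torch
m = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
feats = m.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)                  # expected: torch.Size([1, 7, 7, 768]), NHWC and unpooled
print(m.forward_head(feats).shape)  # expected: torch.Size([1, 1000])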
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'head.fc.weight' in state_dict:
return state_dict
import re
out_dict = {}
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
for k, v in state_dict.items():
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
def _create_swin_transformer(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
SwinTransformer, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
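A hedged example of the multi-scale feature extraction this builder wires up; channel and reduction metadata come from feature_info, while the exact tensor layout of the returned features depends on the feature wrapper.
import timm
import torch
m = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, features_only=True)
print(m.feature_info.channels())   # expected: [96, 192, 384, 768]
print(m.feature_info.reduction())  # expected: [4, 8, 16, 32]
for f in m(torch.randn(1, 3, 224, 224)):
    print(f.shape)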
@register_model
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'license': 'mit', **kwargs
}
@register_model
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
default_cfgs = generate_default_cfgs({
'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
'swin_small_patch4_window7_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
'swin_base_patch4_window7_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
'swin_base_patch4_window12_384.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
@register_model
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
# tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),
'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_small_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_base_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_base_patch4_window12_384.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
'swin_large_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_large_patch4_window12_384.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
@register_model
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-S @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
'swin_s3_tiny_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
'swin_s3_small_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
'swin_s3_base_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
})
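With the tagged cfgs above, a bare architecture name resolves to its highest-priority pretrained tag, and an explicit .tag suffix selects a specific weight set. A small hedged example (tag names taken from the cfg keys above):
import timm
print(timm.list_models('swin_tiny_patch4_window7_224*', include_tags=True))
m = timm.create_model('swin_tiny_patch4_window7_224.ms_in22k', pretrained=False)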
@register_model
@ -635,44 +707,53 @@ def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
@register_model
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
""" Swin-B @ 384x384, trained ImageNet-22k
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-S @ 224x224
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
""" Swin-B @ 224x224, trained ImageNet-22k
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-B @ 224x224
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
""" Swin-L @ 384x384, trained ImageNet-22k
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B @ 384x384
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
""" Swin-L @ 224x224, trained ImageNet-22k
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-L @ 224x224
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-L @ 384x384
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_s3_tiny_224(pretrained=False, **kwargs):
""" Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
""" Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
@ -682,7 +763,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs):
@register_model
def swin_s3_small_224(pretrained=False, **kwargs):
""" Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
""" Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
@ -692,10 +773,17 @@ def swin_s3_small_224(pretrained=False, **kwargs):
@register_model
def swin_s3_base_224(pretrained=False, **kwargs):
""" Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
""" Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)
register_model_deprecations(__name__, {
'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
})

View File

@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# Written by Ze Liu
# --------------------------------------------------------
import math
from typing import Tuple, Optional
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -21,76 +21,14 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'swinv2_tiny_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_tiny_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_small_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_small_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192)
),
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'swinv2_large_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192)
),
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
input_size=(3, 256, 256)
),
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
}
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
def window_partition(x, window_size: Tuple[int, int]):
@ -141,9 +79,15 @@ class WindowAttention(nn.Module):
"""
def __init__(
self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
pretrained_window_size=[0, 0],
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
@ -231,8 +175,8 @@ class WindowAttention(nn.Module):
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
num_win = mask.shape[0]
attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
@ -246,29 +190,42 @@ class WindowAttention(nn.Module):
return x
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pretraining.
class SwinTransformerV2Block(nn.Module):
""" Swin Transformer Block.
"""
def __init__(
self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
pretrained_window_size=0,
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
num_heads: Number of attention heads.
window_size: Window size.
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
pretrained_window_size: Window size in pretraining.
"""
super().__init__()
self.dim = dim
self.input_resolution = to_2tuple(input_resolution)
@ -280,13 +237,23 @@ class SwinTransformerBlock(nn.Module):
self.mlp_ratio = mlp_ratio
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size))
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size),
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -322,10 +289,7 @@ class SwinTransformerBlock(nn.Module):
return tuple(window_size), tuple(shift_size)
def _attn(self, x):
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
x = x.view(B, H, W, C)
B, H, W, C = x.shape
# cyclic shift
has_shift = any(self.shift_size)
@ -350,113 +314,124 @@ class SwinTransformerBlock(nn.Module):
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
return x
def forward(self, x):
B, H, W, C = x.shape
x = x + self.drop_path1(self.norm1(self._attn(x)))
x = x.reshape(B, -1, C)
x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = x.reshape(B, H, W, C)
return x
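# Res-post-norm: the norm is applied to the attention / MLP output before the residual add,
# unlike the V1 block which normalizes the input first:
#   V1 (pre-norm):  x = x + drop_path(attn(norm1(x)))
#   V2 (post-norm): x = x + drop_path1(norm1(attn(x)))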
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
""" Patch Merging Layer.
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
"""
Args:
dim: Number of input channels.
out_dim: Number of output channels (or 2 * dim if None).
norm_layer: Normalization layer.
"""
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
self.out_dim = out_dim or 2 * dim
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
self.norm = norm_layer(self.out_dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0, f"x size ({H}*{W}) are not even.")
_assert(W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.reduction(x)
x = self.norm(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
pretrained_window_size (int): Local window size in pre-training.
class SwinTransformerV2Stage(nn.Module):
""" A Swin Transformer V2 Stage.
"""
def __init__(
self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0):
self,
dim,
out_dim,
input_resolution,
depth,
num_heads,
window_size,
downsample=False,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
pretrained_window_size=0,
output_nchw=False,
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
depth: Number of blocks.
num_heads: Number of attention heads.
window_size: Local window size.
downsample: Use downsample layer at start of the block.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
pretrained_window_size: Local window size in pretraining.
output_nchw: Output tensors on NCHW format instead of NHWC.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
self.depth = depth
self.output_nchw = output_nchw
self.grad_checkpointing = False
# patch merging / downsample layer
if downsample:
self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer)
else:
assert dim == out_dim
self.downsample = nn.Identity()
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
SwinTransformerV2Block(
dim=out_dim,
input_resolution=self.output_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size)
pretrained_window_size=pretrained_window_size,
)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
def forward(self, x):
x = self.downsample(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
x = self.downsample(x)
return x
def _init_respostnorm(self):
@ -468,88 +443,113 @@ class BasicLayer(nn.Module):
class SwinTransformerV2(nn.Module):
r""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/abs/2111.09883
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/abs/2111.09883
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
pretrained_window_sizes=(0, 0, 0, 0), **kwargs):
self,
img_size: _int_or_tuple_2_t = 224,
patch_size: int = 4,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Callable = nn.LayerNorm,
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
**kwargs,
):
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
embed_dim: Patch embedding dimension.
depths: Depth of each Swin Transformer stage (layer).
num_heads: Number of attention heads in different layers.
window_size: Window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop_rate: Dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
pretrained_window_sizes: Pretrained window sizes of each layer.
"""
super().__init__()
self.num_classes = num_classes
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.output_fmt = 'NHWC'
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.feature_info = []
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim[0],
norm_layer=norm_layer,
output_fmt='NHWC',
)
# absolute position embedding
if ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
else:
self.absolute_pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
layers = []
in_dim = embed_dim[0]
scale = 1
for i in range(self.num_layers):
out_dim = embed_dim[i]
layers += [SwinTransformerV2Stage(
dim=in_dim,
out_dim=out_dim,
input_resolution=(
self.patch_embed.grid_size[0] // (2 ** i_layer),
self.patch_embed.grid_size[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
self.patch_embed.grid_size[0] // scale,
self.patch_embed.grid_size[1] // scale),
depth=depths[i],
downsample=i > 0,
num_heads=num_heads[i],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
drop_path=dpr[i],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer]
)
self.layers.append(layer)
pretrained_window_size=pretrained_window_sizes[i],
)]
in_dim = out_dim
if i > 0:
scale *= 2
self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')]
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
input_fmt=self.output_fmt,
)
self.apply(self._init_weights)
for bly in self.layers:
@ -563,7 +563,7 @@ class SwinTransformerV2(nn.Module):
@torch.jit.ignore
def no_weight_decay(self):
nod = {'absolute_pos_embed'}
nod = set()
for n, m in self.named_modules():
if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
nod.add(n)
@ -587,31 +587,20 @@ class SwinTransformerV2(nn.Module):
@torch.jit.ignore
def get_classifier(self):
return self.head
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.layers(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
return x if pre_logits else self.head(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
@ -620,25 +609,102 @@ class SwinTransformerV2(nn.Module):
def checkpoint_filter_fn(state_dict, model):
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
if 'head.fc.weight' in state_dict:
return state_dict
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
import re
for k, v in state_dict.items():
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
continue # skip buffers that should not be persistent
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
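Because patch merging now sits at the start of stage i + 1 rather than the end of stage i, the regex above shifts every downsample key one stage later. A quick standalone illustration:
import re
k = 'layers.0.downsample.reduction.weight'
print(re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k))
# layers.1.downsample.reduction.weight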
def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
SwinTransformerV2, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'license': 'mit', **kwargs
}
default_cfgs = generate_default_cfgs({
'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
),
'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
),
'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'swinv2_tiny_window8_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
),
'swinv2_tiny_window16_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
),
'swinv2_small_window8_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
),
'swinv2_small_window16_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
),
'swinv2_base_window8_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
),
'swinv2_base_window16_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
),
'swinv2_base_window12_192.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
'swinv2_large_window12_192.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
})
@register_model
def swinv2_tiny_window16_256(pretrained=False, **kwargs):
"""
@ -694,62 +760,72 @@ def swinv2_base_window8_256(pretrained=False, **kwargs):
@register_model
def swinv2_base_window12_192_22k(pretrained=False, **kwargs):
def swinv2_base_window12_192(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer_v2('swinv2_base_window12_192', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
def swinv2_base_window12to16_192to256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_base_window12to16_192to256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
def swinv2_base_window12to24_192to384(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_base_window12to24_192to384', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12_192_22k(pretrained=False, **kwargs):
def swinv2_large_window12_192(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer_v2('swinv2_large_window12_192', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
def swinv2_large_window12to16_192to256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_large_window12to16_192to256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
def swinv2_large_window12to24_192to384(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_large_window12to24_192to384', pretrained=pretrained, **model_kwargs)
register_model_deprecations(__name__, {
'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
})
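Assuming the deprecation shim keeps legacy names creatable (with a warning), the old dataset-suffixed names still resolve to the new tagged configs:
import timm
m = timm.create_model('swinv2_base_window12_192_22k', pretrained=False)  # maps to 'swinv2_base_window12_192.ms_in22k'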

View File

@ -37,71 +37,17 @@ import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, to_2tuple, _assert
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.9,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj',
'classifier': 'head',
**kwargs,
}
default_cfgs = {
'swinv2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_tiny_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_small_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
}
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
"""Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
return x.permute(0, 2, 3, 1)
@ -230,22 +176,14 @@ class WindowMultiHeadAttention(nn.Module):
relative_position_bias = relative_position_bias.unsqueeze(0)
return relative_position_bias
def _forward_sequential(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
"""
# FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
assert False, "not implemented"
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
""" Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
mask (Optional[torch.Tensor]): Attention mask for the shift case
def _forward_batch(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""This function performs standard (non-sequential) scaled cosine self-attention.
Returns:
Output tensor of the shape [B * windows, N, C]
"""
Bw, L, C = x.shape
@ -272,22 +210,8 @@ class WindowMultiHeadAttention(nn.Module):
x = self.proj_drop(x)
return x
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
""" Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
mask (Optional[torch.Tensor]): Attention mask for the shift case
Returns:
Output tensor of the shape [B * windows, N, C]
"""
if self.sequential_attn:
return self._forward_sequential(x, mask)
else:
return self._forward_batch(x, mask)
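For context, a minimal sketch of the scaled cosine attention that _forward_batch computes; tensor names and the clamp constant follow the Swin V2 formulation, but this is an illustration, not the module's exact internals.

import math
import torch
import torch.nn.functional as F

def cosine_attention(q, k, v, logit_scale, mask=None):
    # q, k, v: (B * windows, num_heads, N, head_dim); logit_scale: (num_heads, 1, 1)
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    # learnable per-head temperature, clamped as in the Swin V2 paper
    attn = attn * torch.clamp(logit_scale, max=math.log(1. / 0.01)).exp()
    if mask is not None:
        attn = attn + mask  # additive mask for the shifted-window case
    return attn.softmax(dim=-1) @ v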
class SwinTransformerBlock(nn.Module):
class SwinTransformerV2CrBlock(nn.Module):
r"""This class implements the Swin transformer block.
Args:
@ -321,7 +245,7 @@ class SwinTransformerBlock(nn.Module):
sequential_attn: bool = False,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super(SwinTransformerBlock, self).__init__()
super(SwinTransformerV2CrBlock, self).__init__()
self.dim: int = dim
self.feat_size: Tuple[int, int] = feat_size
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
@ -410,9 +334,7 @@ class SwinTransformerBlock(nn.Module):
self._make_attention_mask()
def _shifted_window_attn(self, x):
H, W = self.feat_size
B, L, C = x.shape
x = x.view(B, H, W, C)
B, H, W, C = x.shape
# cyclic shift
sh, sw = self.shift_size
@ -441,7 +363,6 @@ class SwinTransformerBlock(nn.Module):
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
x = x.view(B, L, C)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -455,8 +376,12 @@ class SwinTransformerBlock(nn.Module):
"""
# post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
B, H, W, C = x.shape
x = x.reshape(B, -1, C)
x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
x = x.reshape(B, H, W, C)
return x
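The residual ordering above is the Swin V2 'res-post-norm' pattern (op, then norm, then add), in contrast to the pre-norm blocks of Swin V1. A self-contained sketch of just that ordering, with illustrative names:

import torch
import torch.nn as nn

class PostNormResidual(nn.Module):
    # y = x + norm(op(x)) -- the post-norm residual used by both branches of the block
    def __init__(self, dim, op):
        super().__init__()
        self.op = op
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.norm(self.op(x))

# e.g. an MLP branch: PostNormResidual(96, nn.Sequential(nn.Linear(96, 384), nn.GELU(), nn.Linear(384, 96)))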
@ -479,12 +404,10 @@ class PatchMerging(nn.Module):
Returns:
output (torch.Tensor): Output tensor of the shape [B, H // 2, W // 2, 2 * C]
"""
B, C, H, W = x.shape
# unfold + BCHW -> BHWC together
# ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
B, H, W, C = x.shape
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.norm(x)
x = bhwc_to_bchw(self.reduction(x))
x = self.reduction(x)
return x
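A quick shape check of the NHWC merge above: each 2x2 spatial neighborhood is folded into the channel dimension (4 * C) before the Linear reduction to 2 * C. Sizes are illustrative.

import torch

B, H, W, C = 2, 8, 8, 96
x = torch.randn(B, H, W, C)
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
assert x.shape == (B, H // 2, W // 2, 4 * C)  # norm + Linear(4 * C, 2 * C) follow in PatchMerging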
@ -511,7 +434,7 @@ class PatchEmbed(nn.Module):
return x
class SwinTransformerStage(nn.Module):
class SwinTransformerV2CrStage(nn.Module):
r"""This class implements a stage of the Swin transformer including multiple layers.
Args:
@ -549,12 +472,16 @@ class SwinTransformerStage(nn.Module):
extra_norm_stage: bool = False,
sequential_attn: bool = False,
) -> None:
super(SwinTransformerStage, self).__init__()
super(SwinTransformerV2CrStage, self).__init__()
self.downscale: bool = downscale
self.grad_checkpointing: bool = False
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
if downscale:
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
embed_dim = embed_dim * 2
else:
self.downsample = nn.Identity()
def _extra_norm(index):
i = index + 1
@ -562,9 +489,8 @@ class SwinTransformerStage(nn.Module):
return True
return i == depth if extra_norm_stage else False
embed_dim = embed_dim * 2 if downscale else embed_dim
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
SwinTransformerV2CrBlock(
dim=embed_dim,
num_heads=num_heads,
feat_size=self.feat_size,
@ -602,18 +528,15 @@ class SwinTransformerStage(nn.Module):
Returns:
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
"""
x = bchw_to_bhwc(x)
x = self.downsample(x)
B, C, H, W = x.shape
L = H * W
x = bchw_to_bhwc(x).reshape(B, L, C)
for block in self.blocks:
# Perform checkpointing if utilized
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
x = bhwc_to_bchw(x.reshape(B, H, W, -1))
x = bhwc_to_bchw(x)
return x
@ -676,39 +599,54 @@ class SwinTransformerV2Cr(nn.Module):
self.img_size: Tuple[int, int] = img_size
self.window_size: int = window_size
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
self.feature_info = []
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer,
)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):
stage_scale = 2 ** max(index - 1, 0)
stages.append(
SwinTransformerStage(
embed_dim=embed_dim * stage_scale,
depth=depth,
downscale=index != 0,
feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate,
drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)
)
in_dim = embed_dim
in_scale = 1
for stage_idx, (depth, num_heads) in enumerate(zip(depths, num_heads)):
stages += [SwinTransformerV2CrStage(
embed_dim=in_dim,
depth=depth,
downscale=stage_idx != 0,
feat_size=(
patch_grid_size[0] // in_scale,
patch_grid_size[1] // in_scale
),
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate,
drop_attn=attn_drop_rate,
drop_path=dpr[stage_idx],
extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)]
if stage_idx != 0:
in_dim *= 2
in_scale *= 2
self.feature_info += [dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}')]
self.stages = nn.Sequential(*stages)
self.global_pool: str = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)
# current weight init skips custom init and uses pytorch layer defaults, seems to work well
# FIXME more experiments needed
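The feature_info bookkeeping in the stage loop above doubles channels and the reduction factor after every downsampling stage. A standalone loop reproducing those entries for an embed_dim=96, four-stage config, purely for illustration:

embed_dim, num_stages = 96, 4
in_dim, in_scale = embed_dim, 1
feature_info = []
for stage_idx in range(num_stages):
    if stage_idx != 0:
        in_dim *= 2
        in_scale *= 2
    feature_info.append(dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}'))
print(feature_info)  # num_chs 96/192/384/768 at reductions 4/8/16/32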
@ -765,7 +703,7 @@ class SwinTransformerV2Cr(nn.Module):
Returns:
head (nn.Module): Current classification head
"""
return self.head
return self.head.fc
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Method results the classification head
@ -774,10 +712,8 @@ class SwinTransformerV2Cr(nn.Module):
num_classes (int): Number of classes to be predicted
global_pool (str): Unused
"""
self.num_classes: int = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@ -785,9 +721,7 @@ class SwinTransformerV2Cr(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(2, 3))
return x if pre_logits else self.head(x)
return self.head(x, pre_logits=True) if pre_logits else self.head(x)
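With pooling and classification now behind ClassifierHead, the usual timm head helpers apply. A usage sketch; the variant name is just one of those registered below, and the weights here are random (pretrained=False).

import timm
import torch

model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)                     # unpooled NCHW feature map
pooled = model.forward_head(feats, pre_logits=True)   # pooled, pre-classifier features
logits = model(x)                                     # full forward pass

model.reset_classifier(num_classes=0)                 # drop the classifier, keep pooling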
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
@ -814,30 +748,93 @@ def init_weights(module: nn.Module, name: str = ''):
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
if 'head.fc.weight' in state_dict:
return state_dict
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if 'tau' in k:
# convert old tau based checkpoints -> logit_scale (inverse)
v = torch.log(1 / v)
k = k.replace('tau', 'logit_scale')
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
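The tau-to-logit_scale remap above is just the inverse-log relation; a tiny numeric check with made-up values:

import torch

tau = torch.tensor([0.5, 1.0, 2.0])        # old checkpoints stored a temperature tau
logit_scale = torch.log(1 / tau)           # new layout stores log(1 / tau)
assert torch.allclose(logit_scale.exp(), 1 / tau)  # exp() recovers the 1 / tau scaling at runtime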
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
SwinTransformerV2Cr, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs
)
return model
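Because the builder now passes a feature_cfg, features_only extraction works for these models. A usage sketch; the out_indices values and variant name are examples.

import timm
import torch

fx = timm.create_model('swinv2_cr_tiny_ns_224', features_only=True, out_indices=(0, 1, 2, 3))
feats = fx(torch.randn(1, 3, 224, 224))
print(fx.feature_info.channels())    # e.g. [96, 192, 384, 768]
print(fx.feature_info.reduction())   # e.g. [4, 8, 16, 32]
for f in feats:
    print(f.shape)                   # NCHW maps, stride doubling per stage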
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.9,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj',
'classifier': 'head.fc',
**kwargs,
}
default_cfgs = generate_default_cfgs({
'swinv2_cr_tiny_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_tiny_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_tiny_ns_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_small_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_256.untrained': _cfg(
url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
'swinv2_cr_base_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_ns_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_large_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_large_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_huge_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_huge_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
})
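With generate_default_cfgs, each entry above becomes a model.tag pretrained config. A quick way to inspect what ended up registered; the printed values are indicative.

import timm

print(timm.list_models('swinv2_cr*', pretrained=True))                   # variants with downloadable weights
print(timm.get_pretrained_cfg_value('swinv2_cr_small_224', 'crop_pct'))  # 0.9 per the cfg above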
@register_model
def swinv2_cr_tiny_384(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
@ -915,6 +912,19 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_cr_small_ns_256(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 256x256, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_cr_base_384(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
@ -961,8 +971,7 @@ def swinv2_cr_large_384(pretrained=False, **kwargs):
num_heads=(6, 12, 24, 48),
**kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs)
@register_model

View File

@ -41,9 +41,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_,
resample_abs_pos_embed, RmsNorm
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this

View File

@ -20,8 +20,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import _create_vision_transformer

View File

@ -17,8 +17,7 @@ from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
from ._builder import build_model_with_cfg
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this

View File

@ -1 +1 @@
__version__ = '0.8.15dev0'
__version__ = '0.8.16dev0'

View File

@ -255,8 +255,7 @@ def validate(args):
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
valid_labels = {int(line.rstrip()) for line in f}
valid_labels = [i in valid_labels for i in range(args.num_classes)]
valid_labels = [int(line.rstrip()) for line in f]
else:
valid_labels = None
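The list of label indices read above is typically used to subset the model's output columns before scoring; a minimal sketch of that step with illustrative tensor names, not the exact code in validate.py.

import torch

output = torch.randn(8, 1000)        # logits over the full label space
valid_labels = [1, 5, 7]             # indices read from the --valid-labels file
output = output[:, valid_labels]     # keep only the columns present in the validation set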