Merge pull request #2165 from huggingface/more_vits_better_getter
Add forward_intermediates() support to coatnet, maxvit, swin*. Tweak feature extraction interfaces. Prep new vit weights while testing.

commit 3e7ab12af9
tests/test_models.py

@@ -49,8 +49,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):

 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
-    'cait_*', 'xcit_*', 'volo_*',
+    'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
+    'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

@@ -388,13 +389,12 @@ def test_model_forward_features(model_name, batch_size):

 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates_features(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""
-    model = create_model(model_name, pretrained=False, features_only=True)
+    model = create_model(model_name, pretrained=False, features_only=True, feature_cls='getter')
     model.eval()
     print(model.feature_info.out_indices)
     expected_channels = model.feature_info.channels()
     expected_reduction = model.feature_info.reduction()

@@ -420,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):

 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""

@@ -429,18 +429,19 @@ def test_model_forward_intermediates(model_name, batch_size):
     feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
     expected_channels = feature_info.channels()
     expected_reduction = feature_info.reduction()
-    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+    assert len(expected_channels) >= 3  # all models here should have at least 3 feature levels

     input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
     if max(input_size) > MAX_FFEAT_SIZE:
         pytest.skip("Fixed input size model > limit.")
-    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    output_fmt = 'NCHW'  # NOTE output_fmt determined by forward_intermediates() arg, not model attribute
     feat_axis = get_channel_dim(output_fmt)
     spatial_axis = get_spatial_dim(output_fmt)
     import math

     output, intermediates = model.forward_intermediates(
         torch.randn((batch_size, *input_size)),
         output_fmt=output_fmt,
     )
     assert len(expected_channels) == len(intermediates)
     spatial_size = input_size[-2:]
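Usage sketch of what the updated tests exercise — a hypothetical snippet assuming a timm build that includes this PR; the model name is just an example:

import timm
import torch

# features_only now routed through the FeatureGetterNet wrapper when feature_cls='getter'
model = timm.create_model('resnet50', pretrained=False, features_only=True, feature_cls='getter')
model.eval()
feats = model(torch.randn(1, 3, 224, 224))
for f, c in zip(feats, model.feature_info.channels()):
    assert f.shape[1] == c  # NCHW channel dim should match feature_info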
timm/layers/adaptive_avgmax_pool.py

@@ -134,6 +134,7 @@ class SelectAdaptivePool2d(nn.Module):
         super(SelectAdaptivePool2d, self).__init__()
         assert input_fmt in ('NCHW', 'NHWC')
         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
+        pool_type = pool_type.lower()
         if not pool_type:
             self.pool = nn.Identity()  # pass through
             self.flatten = nn.Flatten(1) if flatten else nn.Identity()

@@ -145,8 +146,10 @@ class SelectAdaptivePool2d(nn.Module):
                 self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
             elif pool_type.endswith('max'):
                 self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
-            else:
+            elif pool_type == 'fast' or pool_type.endswith('avg'):
                 self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
+            else:
+                assert False, 'Invalid pool type: %s' % pool_type
             self.flatten = nn.Identity()
         else:
             assert input_fmt == 'NCHW'

@@ -156,8 +159,10 @@ class SelectAdaptivePool2d(nn.Module):
                 self.pool = AdaptiveCatAvgMaxPool2d(output_size)
             elif pool_type == 'max':
                 self.pool = nn.AdaptiveMaxPool2d(output_size)
-            else:
+            elif pool_type == 'avg':
                 self.pool = nn.AdaptiveAvgPool2d(output_size)
+            else:
+                assert False, 'Invalid pool type: %s' % pool_type
             self.flatten = nn.Flatten(1) if flatten else nn.Identity()

     def is_identity(self):
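The pool_type dispatch above now fails loudly on unknown strings instead of silently falling back to average pooling. A minimal sketch of the fast-pool path (assumes only the public timm.layers API):

import torch
from timm.layers import SelectAdaptivePool2d

pool = SelectAdaptivePool2d(pool_type='fast', flatten=True, input_fmt='NHWC')
out = pool(torch.randn(2, 7, 7, 512))  # 'fast' hits the avg branch, NHWC handled natively
print(out.shape)  # torch.Size([2, 512]); pool_type='foo' would now assert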
timm/models/_builder.py

@@ -2,7 +2,7 @@ import dataclasses
 import logging
 import os
 from copy import deepcopy
-from typing import Optional, Dict, Callable, Any, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 from torch import nn as nn
 from torch.hub import load_state_dict_from_url

@@ -359,15 +359,15 @@ def build_model_with_cfg(
     * pruning config / model adaptation

     Args:
-        model_cls (nn.Module): model class
-        variant (str): model variant name
-        pretrained (bool): load pretrained weights
-        pretrained_cfg (dict): model's pretrained weight/task config
-        model_cfg (Optional[Dict]): model's architecture config
-        feature_cfg (Optional[Dict]: feature extraction adapter config
-        pretrained_strict (bool): load pretrained weights strictly
-        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
-        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+        model_cls: model class
+        variant: model variant name
+        pretrained: load pretrained weights
+        pretrained_cfg: model's pretrained weight/task config
+        model_cfg: model's architecture config
+        feature_cfg: feature extraction adapter config
+        pretrained_strict: load pretrained weights strictly
+        pretrained_filter_fn: filter callable for pretrained weights
+        kwargs_filter: kwargs to filter before passing to model
         **kwargs: model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)

@@ -392,6 +392,8 @@ def build_model_with_cfg(
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')
+        if 'feature_cls' in kwargs:
+            feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

     # Instantiate the model
     if model_cfg is None:

@@ -418,24 +420,36 @@ def build_model_with_cfg(

     # Wrap the model in a feature extraction module if enabled
     if features:
-        feature_cls = FeatureListNet
-        output_fmt = getattr(model, 'output_fmt', None)
-        if output_fmt is not None:
-            feature_cfg.setdefault('output_fmt', output_fmt)
+        use_getter = False
         if 'feature_cls' in feature_cfg:
             feature_cls = feature_cfg.pop('feature_cls')
             if isinstance(feature_cls, str):
                 feature_cls = feature_cls.lower()
+
+                # flatten_sequential only valid for some feature extractors
+                if feature_cls not in ('dict', 'list', 'hook'):
+                    feature_cfg.pop('flatten_sequential', None)
+
                 if 'hook' in feature_cls:
                     feature_cls = FeatureHookNet
+                elif feature_cls == 'list':
+                    feature_cls = FeatureListNet
                 elif feature_cls == 'dict':
                     feature_cls = FeatureDictNet
                 elif feature_cls == 'fx':
                     feature_cls = FeatureGraphNet
                 elif feature_cls == 'getter':
+                    use_getter = True
                     feature_cls = FeatureGetterNet
                 else:
                     assert False, f'Unknown feature class {feature_cls}'
+        else:
+            feature_cls = FeatureListNet
+
+        output_fmt = getattr(model, 'output_fmt', None)
+        if output_fmt is not None and not use_getter:  # don't set default for intermediate feat getter
+            feature_cfg.setdefault('output_fmt', output_fmt)
+
         model = feature_cls(model, **feature_cfg)
         model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained cfg
         model.default_cfg = model.pretrained_cfg  # alias for rename backwards compat (default_cfg -> pretrained_cfg)
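With the out_indices / feature_cls passthrough above, both can be supplied as plain create_model kwargs — a sketch, assuming this branch of timm:

import timm

m = timm.create_model(
    'resnet50',
    features_only=True,
    feature_cls='getter',  # selects FeatureGetterNet, which calls forward_intermediates()
    out_indices=(2, 3, 4),
)
print(m.feature_info.reduction())  # e.g. [8, 16, 32]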
timm/models/_features.py

@@ -363,7 +363,7 @@ class FeatureHookNet(nn.ModuleDict):
             out_map: Optional[Sequence[Union[int, str]]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
-            no_rewrite: bool = False,
+            no_rewrite: Optional[bool] = None,
             flatten_sequential: bool = False,
             default_hook_type: str = 'forward',
     ):

@@ -385,7 +385,8 @@ class FeatureHookNet(nn.ModuleDict):
         self.return_dict = return_dict
         self.output_fmt = Format(output_fmt)
         self.grad_checkpointing = False
-
+        if no_rewrite is None:
+            no_rewrite = not flatten_sequential
         layers = OrderedDict()
         hooks = []
         if no_rewrite:

@@ -467,7 +468,7 @@ class FeatureGetterNet(nn.ModuleDict):
         self.out_indices = out_indices
         self.out_map = out_map
         self.return_dict = return_dict
-        self.output_fmt = output_fmt
+        self.output_fmt = Format(output_fmt)
         self.norm = norm

     def forward(self, x):
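output_fmt is now normalized to the Format enum at construction, so downstream axis helpers can rely on it. Quick sketch:

from timm.layers import Format

fmt = Format('NCHW')
assert fmt is Format.NCHW  # invalid strings raise ValueError up front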
timm/models/_features_fx.py

@@ -15,7 +15,7 @@ except ImportError:
     has_fx_feature_extraction = False

 # Layers we want to treat as leaf modules
-from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
+from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
 from timm.layers.non_local_attn import BilinearAttnTransform
 from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 from timm.layers.norm_act import (

@@ -108,12 +108,14 @@ class FeatureGraphNet(nn.Module):
             model: nn.Module,
             out_indices: Tuple[int, ...],
             out_map: Optional[Dict] = None,
+            output_fmt: str = 'NCHW',
     ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
         self.feature_info = _get_feature_info(model, out_indices)
         if out_map is not None:
             assert len(out_map) == len(out_indices)
+        self.output_fmt = Format(output_fmt)
         return_nodes = _get_return_layers(self.feature_info, out_map)
         self.graph_module = create_feature_extractor(model, return_nodes)
timm/models/_registry.py

@@ -184,7 +184,7 @@ def _expand_filter(filter: str):

 def list_models(
         filter: Union[str, List[str]] = '',
-        module: str = '',
+        module: Union[str, List[str]] = '',
        pretrained: bool = False,
        exclude_filters: Union[str, List[str]] = '',
        name_matches_cfg: bool = False,

@@ -217,7 +217,16 @@ def list_models(
     # FIXME should this be default behaviour? or default to include_tags=True?
     include_tags = pretrained

-    all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
+    if not module:
+        all_models: Set[str] = set(_model_entrypoints.keys())
+    else:
+        if isinstance(module, str):
+            all_models: Set[str] = _module_to_models[module]
+        else:
+            assert isinstance(module, Sequence)
+            all_models: Set[str] = set()
+            for m in module:
+                all_models.update(_module_to_models[m])
     all_models = all_models - _deprecated_models.keys()  # remove deprecated models from listings

     if include_tags:
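The module argument now accepts a list of module names; the result is the union of each module's model set. Sketch, assuming this branch:

import timm

names = timm.list_models(module=['beit', 'eva'])
print(len(names), names[:3])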
timm/models/beit.py

@@ -407,7 +407,7 @@ class Beit(nn.Module):
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@@ -424,7 +424,7 @@ class Beit(nn.Module):
         Returns:

         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.blocks), indices)

@@ -436,6 +436,7 @@ class Beit(nn.Module):
         if self.pos_embed is not None:
             x = x + self.pos_embed
         x = self.pos_drop(x)
+
         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks

@@ -469,19 +470,19 @@ class Beit(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
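Example of the renamed prune arg and the new stop_early=False default in practice — a hypothetical snippet against this branch:

import timm
import torch

model = timm.create_model('beit_base_patch16_224', pretrained=False).eval()
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=[-2, -1],  # last two blocks
    output_fmt='NCHW',
    intermediates_only=True,
)
print([f.shape for f in feats])
model.prune_intermediate_layers(indices=[-2, -1])  # arg renamed from n -> indices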
timm/models/cait.py

@@ -343,7 +343,7 @@ class Cait(nn.Module):
             x: torch.Tensor,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@@ -357,7 +357,7 @@ class Cait(nn.Module):
             output_fmt: Shape of intermediate feature outputs
             intermediates_only: Only return intermediate features
         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.blocks), indices)

@@ -367,6 +367,7 @@ class Cait(nn.Module):
         x = self.patch_embed(x)
         x = x + self.pos_embed
         x = self.pos_drop(x)
+
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks
         else:

@@ -397,19 +398,19 @@ class Cait(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.blocks_token_only = nn.ModuleList()  # prune token blocks with head
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
timm/models/convnext.py

@@ -39,7 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W

 from collections import OrderedDict
 from functools import partial
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -49,6 +49,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -407,6 +408,71 @@ class ConvNeXt(nn.Module):
     def reset_classifier(self, num_classes=0, global_pool=None):
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if feat_idx in take_indices:
+                # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
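For ConvNeXt the stem counts as feature index 0, with stages 1..N after it. Sketch:

import timm
import torch

model = timm.create_model('convnext_tiny', pretrained=False).eval()
inter = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=[0, 1, 4],  # 0 = stem output, 1..4 = stage outputs
    intermediates_only=True,
)
print([t.shape for t in inter])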
timm/models/efficientformer.py

@@ -12,7 +12,7 @@ Based on Apache 2.0 licensed code at https://github.com/snap-research/EfficientF

 Modifications and timm support by / Copyright 2022, Ross Wightman
 """
-from typing import Dict
+from typing import Dict, List, Tuple, Union

 import torch
 import torch.nn as nn

@@ -20,6 +20,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -382,16 +383,19 @@ class EfficientFormer(nn.Module):
         prev_dim = embed_dims[0]

         # stochastic depth decay rule
+        self.num_stages = len(depths)
+        last_stage = self.num_stages - 1
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
+        downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
         stages = []
-        for i in range(len(depths)):
+        self.feature_info = []
+        for i in range(self.num_stages):
             stage = EfficientFormerStage(
                 prev_dim,
                 embed_dims[i],
                 depths[i],
                 downsample=downsamples[i],
-                num_vit=num_vit if i == 3 else 0,
+                num_vit=num_vit if i == last_stage else 0,
                 pool_size=pool_size,
                 mlp_ratio=mlp_ratios,
                 act_layer=act_layer,

@@ -403,7 +407,7 @@ class EfficientFormer(nn.Module):
             )
             prev_dim = embed_dims[i]
             stages.append(stage)
-
+            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)

         # Classifier head

@@ -456,6 +460,76 @@ class EfficientFormer(nn.Module):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        B, C, H, W = x.shape
+
+        last_idx = self.num_stages - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+        feat_idx = 0
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx < last_idx:
+                B, C, H, W = x.shape
+            if feat_idx in take_indices:
+                if feat_idx == last_idx:
+                    x_inter = self.norm(x) if norm else x
+                    intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2))
+                else:
+                    intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)

@@ -534,13 +608,13 @@ default_cfgs = generate_default_cfgs({


 def _create_efficientformer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for EfficientFormer models.')
-
+    out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
         pretrained_filter_fn=_checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model
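EfficientFormer previously raised RuntimeError for features_only; it now builds via the getter wrapper. Sketch:

import timm
import torch

model = timm.create_model('efficientformer_l1', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print(model.feature_info.reduction(), [f.shape for f in feats])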
timm/models/efficientnet.py

@@ -36,7 +36,7 @@ the models and weights open source!

 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import List
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -49,7 +49,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -118,6 +118,7 @@ class EfficientNet(nn.Module):
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs

         # Head + Pooling

@@ -158,6 +159,86 @@ class EfficientNet(nn.Module):
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.conv_head(x)
+        x = self.bn2(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm or max_index < len(self.blocks):
+            self.conv_head = nn.Identity()
+            self.bn2 = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)

@@ -272,7 +353,7 @@ def _create_effnet(variant, pretrained=False, **kwargs):
     model_cls = EfficientNet
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
             features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
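extra_blocks switches indexing from feature_info stages to every block (plus stem). Sketch, same caveats as above:

import timm
import torch

model = timm.create_model('efficientnet_b0', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
per_stage = model.forward_intermediates(x, intermediates_only=True)
per_block = model.forward_intermediates(x, intermediates_only=True, extra_blocks=True)
print(len(per_stage), len(per_block))  # per_block does not align with feature_info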
timm/models/eva.py

@@ -53,6 +53,7 @@ class EvaAttention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = True,
             qkv_fused: bool = True,
+            num_prefix_tokens: int = 1,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             attn_head_dim: Optional[int] = None,

@@ -77,6 +78,7 @@ class EvaAttention(nn.Module):
             head_dim = attn_head_dim
         all_head_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
+        self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()

         if qkv_fused:

@@ -119,8 +121,9 @@ class EvaAttention(nn.Module):
             v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

         if rope is not None:
-            q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
-            k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
+            npt = self.num_prefix_tokens
+            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
+            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)

         if self.fused_attn:
             x = F.scaled_dot_product_attention(

@@ -157,6 +160,7 @@ class EvaBlock(nn.Module):
             swiglu_mlp: bool = False,
             scale_mlp: bool = False,
             scale_attn_inner: bool = False,
+            num_prefix_tokens: int = 1,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,

@@ -191,6 +195,7 @@ class EvaBlock(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qkv_fused=qkv_fused,
+            num_prefix_tokens=num_prefix_tokens,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             attn_head_dim=attn_head_dim,

@@ -253,6 +258,7 @@ class EvaBlockPostNorm(nn.Module):
             swiglu_mlp: bool = False,
             scale_mlp: bool = False,
             scale_attn_inner: bool = False,
+            num_prefix_tokens: int = 1,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,

@@ -286,6 +292,7 @@ class EvaBlockPostNorm(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qkv_fused=qkv_fused,
+            num_prefix_tokens=num_prefix_tokens,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             attn_head_dim=attn_head_dim,

@@ -364,6 +371,7 @@ class Eva(nn.Module):
             norm_layer: Callable = LayerNorm,
             init_values: Optional[float] = None,
             class_token: bool = True,
+            num_reg_tokens: int = 0,
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,

@@ -407,7 +415,7 @@ class Eva(nn.Module):
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_prefix_tokens = 1 if class_token else 0
+        self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False

@@ -427,6 +435,8 @@ class Eva(nn.Module):
         r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None
+        self.cls_embed = class_token and self.reg_token is None

         self.pos_embed = nn.Parameter(
             torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None

@@ -463,6 +473,7 @@ class Eva(nn.Module):
                 swiglu_mlp=swiglu_mlp,
                 scale_mlp=scale_mlp,
                 scale_attn_inner=scale_attn_inner,
+                num_prefix_tokens=self.num_prefix_tokens,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],

@@ -484,6 +495,8 @@ class Eva(nn.Module):
             trunc_normal_(self.pos_embed, std=.02)
         if self.cls_token is not None:
             trunc_normal_(self.cls_token, std=.02)
+        if self.reg_token is not None:
+            trunc_normal_(self.reg_token, std=.02)

         self.fix_init_weight()
         if isinstance(self.head, nn.Linear):

@@ -551,8 +564,17 @@ class Eva(nn.Module):

         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

         if pos_embed is not None:
             x = x + pos_embed

+        if self.reg_token is not None:
+            to_cat = []
+            if self.cls_token is not None:
+                to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+            x = torch.cat(to_cat + [x], dim=1)
+
         x = self.pos_drop(x)

         # obtain shared rotary position embedding and apply patch dropout

@@ -568,7 +590,7 @@ class Eva(nn.Module):
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@@ -622,19 +644,19 @@ class Eva(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):

@@ -695,6 +717,12 @@ def checkpoint_filter_fn(
             # fixed embedding no need to load buffer from checkpoint
             continue

+        # FIXME here while importing new weights, to remove
+        # if k == 'cls_token':
+        #     print('DEBUG: cls token -> reg')
+        #     k = 'reg_token'
+        #     #v = v + state_dict['pos_embed'][0, :]
+
         if 'patch_embed.proj.weight' in k:
             _, _, H, W = model.patch_embed.proj.weight.shape
             if v.shape[-1] != W or v.shape[-2] != H:

@@ -923,6 +951,29 @@ default_cfgs = generate_default_cfgs({
         num_classes=0,
     ),

+    'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
+        #hf_hub_id='timm/',
+        file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+        input_size=(3, 256, 256), crop_pct=0.95,
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
+    ),
+    'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
+        #hf_hub_id='timm/',
+        file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+        input_size=(3, 256, 256), crop_pct=0.95,
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
+    ),
+    'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
+        #hf_hub_id='timm/',
+        file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+        input_size=(3, 256, 256), crop_pct=0.95,
+    ),
+    'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
+        # hf_hub_id='timm/',
+        file='vit_base_gap1_rope-in1k-20230930-5.pth',
+        input_size=(3, 256, 256), crop_pct=0.95,
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
+    ),
 })

@@ -1185,3 +1236,87 @@ def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
     )
     model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def vit_medium_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
+    model_args = dict(
+        img_size=256,
+        patch_size=16,
+        embed_dim=512,
+        depth=12,
+        num_heads=8,
+        qkv_fused=True,
+        qkv_bias=True,
+        init_values=1e-5,
+        class_token=False,
+        num_reg_tokens=1,
+        use_rot_pos_emb=True,
+        use_abs_pos_emb=False,
+        ref_feat_shape=(16, 16),  # 224/14
+    )
+    model = _create_eva('vit_medium_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_mediumd_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
+    model_args = dict(
+        img_size=256,
+        patch_size=16,
+        embed_dim=512,
+        depth=20,
+        num_heads=8,
+        qkv_fused=True,
+        qkv_bias=False,
+        init_values=1e-5,
+        class_token=False,
+        num_reg_tokens=1,
+        use_rot_pos_emb=True,
+        use_abs_pos_emb=False,
+        ref_feat_shape=(16, 16),  # 224/14
+    )
+    model = _create_eva('vit_mediumd_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_betwixt_patch16_rope_reg4_gap_256(pretrained=False, **kwargs) -> Eva:
+    model_args = dict(
+        img_size=256,
+        patch_size=16,
+        embed_dim=640,
+        depth=12,
+        num_heads=10,
+        qkv_fused=True,
+        qkv_bias=True,
+        init_values=1e-5,
+        class_token=False,
+        num_reg_tokens=4,
+        use_rot_pos_emb=True,
+        use_abs_pos_emb=False,
+        ref_feat_shape=(16, 16),  # 224/14
+    )
+    model = _create_eva('vit_betwixt_patch16_rope_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_rope_reg1_gap_256(pretrained=False, **kwargs) -> Eva:
+    model_args = dict(
+        img_size=256,
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        qkv_fused=True,
+        qkv_bias=True,
+        init_values=1e-5,
+        class_token=False,
+        num_reg_tokens=1,
+        use_rot_pos_emb=True,
+        use_abs_pos_emb=False,
+        ref_feat_shape=(16, 16),  # 224/14
+    )
+    model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
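The new rope/reg-token ViTs are built on the Eva block; their weights are still staged from local files in the cfgs above (hf_hub_id commented out), so a sketch would construct them untrained for now:

import timm
import torch

model = timm.create_model('vit_medium_patch16_rope_reg1_gap_256', pretrained=False)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000])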
timm/models/levit.py

@@ -25,7 +25,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # Copyright 2020 Ross Wightman, Apache-2.0 License
 from collections import OrderedDict
 from functools import partial
-from typing import Dict
+from typing import Dict, List, Tuple, Union

 import torch
 import torch.nn as nn

@@ -33,6 +33,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
 from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -634,6 +635,70 @@ class Levit(nn.Module):
         self.head = NormLinear(
             self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        B, C, H, W = x.shape
+        if not self.use_conv:
+            x = x.flatten(2).transpose(1, 2)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if self.use_conv:
+                    intermediates.append(x)
+                else:
+                    intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2))
+            H = (H + 2 - 1) // 2
+            W = (W + 2 - 1) // 2
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if not self.use_conv:

@@ -746,9 +811,8 @@ model_cfgs = dict(
 def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
     is_conv = '_conv' in variant
     out_indices = kwargs.pop('out_indices', (0, 1, 2))
-    if kwargs.get('features_only', None):
-        if not is_conv:
-            raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.')
+    if kwargs.get('features_only', False) and not is_conv:
+        kwargs.setdefault('feature_cls', 'getter')
     if cfg_variant is None:
         if variant in model_cfgs:
             cfg_variant = variant
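Non-conv LeViT variants no longer raise for features_only; they default to the getter wrapper, which reshapes token outputs back to NCHW. Sketch:

import timm
import torch

model = timm.create_model('levit_128', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])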
timm/models/maxxvit.py

@@ -50,6 +50,7 @@ from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act
 from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
 from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -1251,6 +1252,75 @@ class MaxxVit(nn.Module):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        last_idx = len(self.stages)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.head = self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
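With norm=True the final stage intermediate gets self.norm applied, matching the returned final features. Sketch against a coatnet variant (maxxvit module):

import timm
import torch

model = timm.create_model('coatnet_0_rw_224', pretrained=False).eval()
final, inter = model.forward_intermediates(torch.randn(1, 3, 224, 224), norm=True)
print(final.shape, [t.shape for t in inter])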
timm/models/mobilenetv3.py

@@ -7,7 +7,7 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244

 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -20,7 +20,7 @@ from ._builder import build_model_with_cfg, pretrained_cfg_for_features
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -109,6 +109,7 @@ class MobileNetV3(nn.Module):
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs

         # Head + Pooling

@@ -150,6 +151,84 @@ class MobileNetV3(nn.Module):
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < len(self.blocks):
+            self.conv_head = nn.Identity()
+        if prune_head:
+            self.conv_head = nn.Identity()
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_stem(x)
         x = self.bn1(x)

@@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
     model_cls = MobileNetV3
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
            features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
timm/models/mvitv2.py

@@ -839,7 +839,7 @@ class MultiScaleVit(nn.Module):
             x: torch.Tensor,
             indices: Union[int, List[int], Tuple[int]] = None,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@@ -855,13 +855,12 @@ class MultiScaleVit(nn.Module):
         Returns:

         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output shape for MViT-V2 must be NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output shape must be NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.stages), indices)
-        # FIXME slice block/pos_block if < max

         # forward pass
         x, feat_size = self.patch_embed(x)
         B = x.shape[0]

@@ -870,6 +869,7 @@ class MultiScaleVit(nn.Module):
             x = torch.cat((cls_tokens, x), dim=1)
         if self.pos_embed is not None:
             x = x + self.pos_embed
+
         for i, stage in enumerate(self.stages):
             x, feat_size = stage(x, feat_size)
             if i in take_indices:

@@ -891,6 +891,23 @@ class MultiScaleVit(nn.Module):

         return x, intermediates

+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        # FIXME add stage pruning
+        # self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x, feat_size = self.patch_embed(x)
         B, N, C = x.shape
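MViT-V2 intermediates can also be returned token-shaped via output_fmt='NLC'; note stage pruning is still a FIXME above. Sketch:

import timm
import torch

model = timm.create_model('mvitv2_tiny', pretrained=False).eval()
inter = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    output_fmt='NLC',
    intermediates_only=True,
)
print([t.shape for t in inter])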
@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
    get_attn, get_act_layer, get_norm_layer, create_classifier
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations

@ -295,8 +296,8 @@ def drop_blocks(drop_prob: float = 0.):

def make_blocks(
        block_fn: Union[BasicBlock, Bottleneck],
        channels: List[int],
        block_repeats: List[int],
        channels: Tuple[int, ...],
        block_repeats: Tuple[int, ...],
        inplanes: int,
        reduce_first: int = 1,
        output_stride: int = 32,

@ -394,7 +395,7 @@ class ResNet(nn.Module):
    def __init__(
            self,
            block: Union[BasicBlock, Bottleneck],
            layers: List[int],
            layers: Tuple[int, ...],
            num_classes: int = 1000,
            in_chans: int = 3,
            output_stride: int = 32,

@ -497,7 +498,7 @@ class ResNet(nn.Module):
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        channels = [64, 128, 256, 512]
        channels = (64, 128, 256, 512)
        stage_modules, stage_feature_info = make_blocks(
            block,
            channels,

@ -553,6 +554,73 @@ class ResNet(nn.Module):
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Union[int, List[int], Tuple[int]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        if stop_early:
            assert intermediates_only, 'Must use intermediates_only for early stopping.'
        intermediates = []
        take_indices, max_index = feature_take_indices(5, indices)

        # forward pass
        feat_idx = 0
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        if feat_idx in take_indices:
            intermediates.append(x)
        x = self.maxpool(x)

        layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
        if stop_early:
            layer_names = layer_names[:max_index]
        for n in layer_names:
            feat_idx += 1
            x = getattr(self, n)(x)  # won't work with torchscript, but keeps code reasonable, FML
            if feat_idx in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(5, indices)
        layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
        layer_names = layer_names[max_index:]
        for n in layer_names:
            setattr(self, n, nn.Identity())
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)

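The ResNet variant indexes five feature levels (the stem activation plus layer1 through layer4), so feature_take_indices(5, indices) turns an int into "last n of 5". A quick sketch (illustration only; shapes assume a standard resnet50 at 224x224):

import timm
import torch

model = timm.create_model('resnet50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# int indices -> last n of the 5 levels: here layer2, layer3, layer4
final, feats = model.forward_intermediates(x, indices=3)
print(final.shape)                   # torch.Size([1, 2048, 7, 7])
print([f.shape[1] for f in feats])   # [512, 1024, 2048]

# stop_early requires intermediates_only=True, per the assert above
feats = model.forward_intermediates(x, indices=[0, 1], stop_early=True, intermediates_only=True)
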
@ -1246,7 +1314,7 @@ default_cfgs = generate_default_cfgs({
def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-10-T model.
    """
    model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
    model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
    return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))

@ -1254,7 +1322,7 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-14-T model.
    """
    model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
    return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs))

@ -1262,7 +1330,7 @@ def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-18 model.
    """
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
    model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2))
    return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))

@ -1270,7 +1338,7 @@ def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-18-D model.
    """
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs))

@ -1278,7 +1346,7 @@ def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-34 model.
    """
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
    model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3))
    return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs))

@ -1286,7 +1354,7 @@ def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-34-D model.
    """
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs))

@ -1294,7 +1362,7 @@ def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-26 model.
    """
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
    model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2))
    return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs))

@ -1302,7 +1370,7 @@ def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-26-T model.
    """
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep_tiered', avg_down=True)
    return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs))

@ -1310,7 +1378,7 @@ def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-26-D model.
    """
    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs))

@ -1318,7 +1386,7 @@ def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3))
    return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))

@ -1326,7 +1394,7 @@ def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-C model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep')
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep')
    return _create_resnet('resnet50c', pretrained, **dict(model_args, **kwargs))

@ -1334,7 +1402,7 @@ def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-D model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs))

@ -1342,7 +1410,7 @@ def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-S model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep')
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=64, stem_type='deep')
    return _create_resnet('resnet50s', pretrained, **dict(model_args, **kwargs))

@ -1350,7 +1418,7 @@ def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-T model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True)
    return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs))

@ -1358,7 +1426,7 @@ def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3))
    return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs))

@ -1366,7 +1434,7 @@ def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101-C model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep')
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep')
    return _create_resnet('resnet101c', pretrained, **dict(model_args, **kwargs))

@ -1374,7 +1442,7 @@ def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101-D model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs))

@ -1382,7 +1450,7 @@ def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101-S model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep')
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=64, stem_type='deep')
    return _create_resnet('resnet101s', pretrained, **dict(model_args, **kwargs))

@ -1390,7 +1458,7 @@ def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-152 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
    model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3))
    return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs))

@ -1398,7 +1466,7 @@ def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-152-C model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep')
    model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep')
    return _create_resnet('resnet152c', pretrained, **dict(model_args, **kwargs))

@ -1406,7 +1474,7 @@ def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-152-D model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs))

@ -1414,7 +1482,7 @@ def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-152-S model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep')
    model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=64, stem_type='deep')
    return _create_resnet('resnet152s', pretrained, **dict(model_args, **kwargs))

@ -1422,7 +1490,7 @@ def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-200 model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3])
    model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3))
    return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs))

@ -1430,7 +1498,7 @@ def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
def resnet200d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-200-D model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
    model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs))

@ -1442,7 +1510,7 @@ def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet:
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), base_width=128)
    return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs))

@ -1453,7 +1521,7 @@ def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), base_width=128)
    return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs))

@ -1461,7 +1529,7 @@ def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50 model w/ GroupNorm
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], norm_layer='groupnorm')
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), norm_layer='groupnorm')
    return _create_resnet('resnet50_gn', pretrained, **dict(model_args, **kwargs))

@ -1469,7 +1537,7 @@ def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt50-32x4d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4)
    return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1478,7 +1546,7 @@ def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
        stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1487,7 +1555,7 @@ def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt-101 32x4d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4)
    return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1495,7 +1563,7 @@ def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt-101 32x8d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8)
    return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs))

@ -1503,7 +1571,7 @@ def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt-101 32x16d model
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=16)
    return _create_resnet('resnext101_32x16d', pretrained, **dict(model_args, **kwargs))

@ -1511,7 +1579,7 @@ def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt-101 32x32d model
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=32)
    return _create_resnet('resnext101_32x32d', pretrained, **dict(model_args, **kwargs))

@ -1519,7 +1587,7 @@ def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNeXt101-64x4d model.
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4)
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4)
    return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs))

@ -1530,7 +1598,7 @@ def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet:
    in the deep stem and ECA attn.
    """
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
        block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs))

@ -1540,7 +1608,7 @@ def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-D model with eca.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs))

@ -1551,7 +1619,7 @@ def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
    The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))

@ -1562,7 +1630,7 @@ def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
    Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32,
        block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs))

@ -1572,7 +1640,7 @@ def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-D light model with eca.
    """
    model_args = dict(
        block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
        block=Bottleneck, layers=(1, 1, 11, 3), stem_width=32, avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs))

@ -1582,7 +1650,7 @@ def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101-D model with eca.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs))

@ -1593,7 +1661,7 @@ def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
    The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
        block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))

@ -1603,7 +1671,7 @@ def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-200-D model with ECA.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
        block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs))

@ -1613,7 +1681,7 @@ def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-269-D model with ECA.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
        block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs))

@ -1625,7 +1693,7 @@ def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    in the deep stem. This model replaces SE module with the ECA module
    """
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1637,53 +1705,53 @@ def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    in the deep stem. This model replaces SE module with the ECA module
    """
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
    return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet18(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'))
    model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet34(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
    model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet50(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered',
        block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered',
        avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet101(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'))
    model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet152(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'))
    model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))


@register_model
def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
        block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep',
        avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))

@ -1693,7 +1761,7 @@ def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-200-D model with SE attn.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep',
        block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep',
        avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))

@ -1703,7 +1771,7 @@ def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-269-D model with SE attn.
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep',
        block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep',
        avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs))

@ -1715,7 +1783,7 @@ def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    combination of deep stem and avg_pool in downsample.
    """
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
        stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1727,7 +1795,7 @@ def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    in the deep stem.
    """
    model_args = dict(
        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
        block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1735,7 +1803,7 @@ def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1743,7 +1811,7 @@ def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
        block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))

@ -1751,7 +1819,7 @@ def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
        block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs))

@ -1759,7 +1827,7 @@ def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
        block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
        stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs))

@ -1768,7 +1836,7 @@ def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4,
        block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnext101_64x4d', pretrained, **dict(model_args, **kwargs))

@ -1776,7 +1844,7 @@ def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
@register_model
def senet154(pretrained: bool = False, **kwargs) -> ResNet:
    model_args = dict(
        block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
        block=Bottleneck, layers=(3, 8, 36, 3), cardinality=64, base_width=4, stem_type='deep',
        down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
    return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs))

@ -1785,7 +1853,7 @@ def senet154(pretrained: bool = False, **kwargs) -> ResNet:
def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-18 model with blur anti-aliasing
    """
    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d)
    model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), aa_layer=BlurPool2d)
    return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs))

@ -1793,7 +1861,7 @@ def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50 model with blur anti-aliasing
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d)
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d)
    return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))

@ -1802,7 +1870,7 @@ def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-D model with blur anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
        block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d,
        stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))

@ -1812,7 +1880,7 @@ def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101-D model with blur anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
        block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=BlurPool2d,
        stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))

@ -1822,7 +1890,7 @@ def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-34-D model w/ avgpool anti-aliasing
    """
    model_args = dict(
        block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
        block=BasicBlock, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))

@ -1830,7 +1898,7 @@ def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50 model with avgpool anti-aliasing
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d)
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d)
    return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))

@ -1839,7 +1907,7 @@ def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-50-D model with avgpool anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
        block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
        stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))

@ -1849,7 +1917,7 @@ def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a ResNet-101-D model with avgpool anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
        block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=nn.AvgPool2d,
        stem_width=32, stem_type='deep', avg_down=True)
    return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))

@ -1859,7 +1927,7 @@ def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
        block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
        stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
    return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))

@ -1869,7 +1937,7 @@ def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
    """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
        block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
        stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))

@ -1880,7 +1948,7 @@ def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs):
    """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
    """
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 4], cardinality=32, base_width=8,
        block=Bottleneck, layers=(3, 24, 36, 4), cardinality=32, base_width=8,
        stem_width=64, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
        block_args=dict(attn_layer='se'))
    return _create_resnet('seresnextaa201d_32x8d', pretrained, **dict(model_args, **kwargs))

@ -1894,7 +1962,7 @@ def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))

@ -1907,7 +1975,7 @@ def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))

@ -1920,7 +1988,7 @@ def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))

@ -1933,7 +2001,7 @@ def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))

@ -1946,7 +2014,7 @@ def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(4, 29, 53, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))

@ -1960,7 +2028,7 @@ def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(4, 36, 72, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))

@ -1973,7 +2041,7 @@ def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
    """
    attn_layer = partial(get_attn('se'), rd_ratio=0.25)
    model_args = dict(
        block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
        block=Bottleneck, layers=(4, 44, 87, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
        avg_down=True, block_args=dict(attn_layer=attn_layer))
    return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))

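All of these constructors merge user kwargs over the registry defaults via dict(model_args, **kwargs), so the new tuple defaults behave exactly like the old lists while being immutable. A hedged sketch of the override path (block_args is an existing ResNet kwarg; the exact combination shown is illustrative):

import timm

# user kwargs win over model_args, e.g. add SE attention blocks to a plain resnet26t
model = timm.create_model('resnet26t', pretrained=False, block_args=dict(attn_layer='se'))
print(type(model).__name__)  # ResNet
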
@ -26,6 +26,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
    _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@ -607,6 +608,72 @@ class SwinTransformer(nn.Module):
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Union[int, List[int], Tuple[int]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.layers), indices)

        # forward pass
        x = self.patch_embed(x)

        num_stages = len(self.layers)
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.layers
        else:
            stages = self.layers[:max_index + 1]
        for i, stage in enumerate(stages):
            x = stage(x)
            if i in take_indices:
                if norm and i == num_stages - 1:
                    x_inter = self.norm(x)  # applying final norm last intermediate
                else:
                    x_inter = x
                x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.layers), indices)
        self.layers = self.layers[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.layers(x)

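Swin keeps NHWC tensors internally, which is why each intermediate above is permuted before being collected; forward_intermediates() therefore always yields NCHW maps. A small sketch (illustration only; model name and shapes assume the registered swin_tiny_patch4_window7_224):

import timm
import torch

model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# all four stage outputs, already permuted NHWC -> NCHW
feats = model.forward_intermediates(x, indices=4, intermediates_only=True)
print([tuple(f.shape) for f in feats])
# e.g. [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
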
@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# Written by Ze Liu
# --------------------------------------------------------
import math
from typing import Callable, Optional, Tuple, Union, Set, Dict
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn

@ -24,6 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
    resample_patch_embed, ndgrid, get_act_layer, LayerType
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@ -608,6 +609,72 @@ class SwinTransformerV2(nn.Module):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Union[int, List[int], Tuple[int]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.layers), indices)

        # forward pass
        x = self.patch_embed(x)

        num_stages = len(self.layers)
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.layers
        else:
            stages = self.layers[:max_index + 1]
        for i, stage in enumerate(stages):
            x = stage(x)
            if i in take_indices:
                if norm and i == num_stages - 1:
                    x_inter = self.norm(x)  # applying final norm last intermediate
                else:
                    x_inter = x
                x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.layers), indices)
        self.layers = self.layers[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.layers(x)

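prune_intermediate_layers() complements extraction: it truncates self.layers to the deepest stage still needed and optionally drops the norm and head. A sketch (illustration only; assumes the registered swinv2_tiny_window8_256):

import timm
import torch

model = timm.create_model('swinv2_tiny_window8_256', pretrained=False).eval()

# keep only the stages needed for the first two intermediates
taken = model.prune_intermediate_layers(indices=[0, 1], prune_norm=True, prune_head=True)
print(taken)  # the indices retained

x = torch.randn(1, 3, 256, 256)
feats = model.forward_intermediates(x, indices=[0, 1], intermediates_only=True)
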
@ -39,6 +39,7 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model

@ -718,6 +719,62 @@ class SwinTransformerV2Cr(nn.Module):
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Union[int, List[int], Tuple[int]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.patch_embed(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]
        for i, stage in enumerate(stages):
            x = stage(x)
            if i in take_indices:
                intermediates.append(x)

        if intermediates_only:
            return intermediates

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self.stages(x)

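With forward_intermediates() in place, SwinV2-CR models can also be wrapped by the features_only getter path that this PR extends. A sketch (feature_cls='getter' selects the FeatureGetterNet wrapper; model name and out_indices are assumptions):

import timm
import torch

model = timm.create_model(
    'swinv2_cr_tiny_ns_224', pretrained=False,
    features_only=True, feature_cls='getter', out_indices=(1, 2, 3))
model.eval()

x = torch.randn(1, 3, 224, 224)
feats = model(x)  # list of NCHW feature maps, one per out_index
print(model.feature_info.channels(), [tuple(f.shape) for f in feats])
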
@ -409,7 +409,7 @@ class Twins(nn.Module):
            x: torch.Tensor,
            indices: Union[int, List[int], Tuple[int]] = None,
            norm: bool = False,
            stop_early: bool = True,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@ -459,6 +459,22 @@ class Twins(nn.Module):

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        # FIXME add block pruning
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        B = x.shape[0]
        for i, (embed, drop, blocks, pos_blk) in enumerate(

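The stop_early default flip above (True to False) matters: with stop_early=False the full block stack always runs, so the final features stay correct when both outputs are requested. Early exit is now explicitly opt-in, e.g. (a sketch, assuming the registered twins_pcpvt_small):

import timm
import torch

model = timm.create_model('twins_pcpvt_small', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# default now runs the whole network and returns (final, intermediates)
final, feats = model.forward_intermediates(x, indices=2)

# early exit only when just the intermediates are wanted
feats = model.forward_intermediates(x, indices=[0, 1], stop_early=True, intermediates_only=True)
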
@ -638,7 +638,7 @@ class VisionTransformer(nn.Module):
            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            stop_early: bool = True,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@ -655,7 +655,7 @@ class VisionTransformer(nn.Module):
        Returns:

        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)

@ -666,6 +666,7 @@ class VisionTransformer(nn.Module):
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            blocks = self.blocks
        else:

@ -698,21 +699,19 @@ class VisionTransformer(nn.Module):

    def prune_intermediate_layers(
            self,
            n: Union[int, List[int], Tuple[int]] = 1,
            indices: Union[int, List[int], Tuple[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.blocks), n)
        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            if self.attn_pool is not None:
                self.attn_pool = None
            self.fc_norm = nn.Identity()
            self.head = nn.Identity()
            self.reset_classifier(0, '')
        return take_indices

    def get_intermediate_layers(

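The rename n to indices aligns prune_intermediate_layers() with forward_intermediates(), and reset_classifier(0, '') now handles attn_pool, fc_norm, and head in one place. Callers using the old keyword must update; a before/after sketch:

import timm

model = timm.create_model('vit_small_patch16_224', pretrained=False)

# before this change:
#   model.prune_intermediate_layers(n=2)
# after, same behaviour under the unified keyword:
taken = model.prune_intermediate_layers(indices=2, prune_norm=False, prune_head=True)
print(len(model.blocks), taken)  # blocks kept up to the deepest index needed
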
@ -1791,12 +1790,27 @@ default_cfgs = {
|
|||
license='mit',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
|
||||
|
||||
'vit_medium_patch16_reg4_256': _cfg(
|
||||
input_size=(3, 256, 256)),
|
||||
'vit_wee_patch16_reg1_gap_256': _cfg(
|
||||
file='',
|
||||
input_size=(3, 256, 256), crop_pct=0.95),
|
||||
'vit_little_patch16_reg4_gap_256': _cfg(
|
||||
file='',
|
||||
input_size=(3, 256, 256), crop_pct=0.95),
|
||||
'vit_medium_patch16_reg1_gap_256': _cfg(
|
||||
file='vit_medium_gap1-in1k-20231118-8.pth',
|
||||
input_size=(3, 256, 256), crop_pct=0.95),
|
||||
'vit_medium_patch16_reg4_gap_256': _cfg(
|
||||
input_size=(3, 256, 256)),
|
||||
file='vit_medium_gap4-in1k-20231115-8.pth',
|
||||
input_size=(3, 256, 256), crop_pct=0.95),
|
||||
'vit_betwixt_patch16_reg1_gap_256': _cfg(
|
||||
file='vit_betwixt_gap1-in1k-20231121-8.pth',
|
||||
input_size=(3, 256, 256), crop_pct=0.95),
|
||||
'vit_betwixt_patch16_reg4_gap_256': _cfg(
|
||||
file='vit_betwixt_gap4-in1k-20231106-8.pth',
|
||||
input_size=(3, 256, 256), crop_pct=0.95),
|
||||
'vit_base_patch16_reg4_gap_256': _cfg(
|
||||
input_size=(3, 256, 256)),
|
||||
|
||||
'vit_so150m_patch16_reg4_gap_256': _cfg(
|
||||
input_size=(3, 256, 256)),
|
||||
'vit_so150m_patch16_reg4_map_256': _cfg(
|
||||
|
@ -2083,6 +2097,18 @@ def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False,
+        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_betwixt_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
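A quick, hedged sanity check on the new 'betwixt' width (embed_dim=640 sits between medium at 512 and base at 768; the param figure is an estimate):

    import timm

    m = timm.create_model('vit_betwixt_patch16_gap_256', pretrained=False)
    # depth=12, embed_dim=640 works out to roughly 60M parameters.
    print(f'{sum(p.numel() for p in m.parameters()) / 1e6:.1f}M params')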
@ -2714,21 +2740,54 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
-
-
 @register_model
-def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
-        patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-        no_embed_class=True, reg_tokens=4,
+        patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+@register_model
+def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
-        patch_size=16, embed_dim=512, depth=12, num_heads=8,
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
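Worth noting (hedged): of the new variants, only 'wee' swaps in block_fn=ParallelScalingBlock, the parallel attention+MLP block in the ViT-22B style; the others use the default block. A small probe:

    import timm

    wee = timm.create_model('vit_wee_patch16_reg1_gap_256', pretrained=False)
    print(type(wee.blocks[0]).__name__)  # ParallelScalingBlock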
@ -2736,6 +2795,28 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
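The reg1/reg4 variants carry 1 or 4 register (prefix) tokens; a hedged sketch of retrieving them via return_prefix_tokens from the API earlier in this diff:

    import torch
    import timm

    m = timm.create_model('vit_betwixt_patch16_reg4_gap_256', pretrained=False).eval()
    x = torch.randn(1, 3, 256, 256)

    # With return_prefix_tokens=True each intermediate is a (spatial, prefix) pair.
    final, feats = m.forward_intermediates(x, indices=2, return_prefix_tokens=True)
    spatial, prefix = feats[-1]
    print(spatial.shape, prefix.shape)  # (1, 640, 16, 16) and (1, 4, 640)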
@ -394,7 +394,7 @@ class VisionTransformerRelPos(nn.Module):
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@ -411,7 +411,7 @@ class VisionTransformerRelPos(nn.Module):
         Returns:

         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -455,19 +455,19 @@ class VisionTransformerRelPos(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
@ -545,7 +545,7 @@ class VisionTransformerSAM(nn.Module):
             x: torch.Tensor,
             indices: Union[int, List[int], Tuple[int]] = None,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@ -573,6 +573,7 @@ class VisionTransformerSAM(nn.Module):
         x = self.pos_drop(x)
+        x = self.patch_drop(x)
         x = self.norm_pre(x)

         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks
         else:
@ -597,19 +598,19 @@ class VisionTransformerSAM(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = None,
+            indices: Union[int, List[int], Tuple[int]] = None,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             # neck is being treated as equivalent to final norm here
             self.neck = nn.Identity()
         if prune_head:
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
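A hedged sketch of the SAM-specific behavior above, where the convolutional neck stands in for the final norm:

    import timm

    sam = timm.create_model('samvit_base_patch16', pretrained=False)
    # prune_norm=True replaces the neck (treated as the final norm) with Identity.
    sam.prune_intermediate_layers(indices=2, prune_norm=True, prune_head=True)
    print(type(sam.neck).__name__)  # Identity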
@ -713,7 +713,7 @@ class VOLO(nn.Module):
             x: torch.Tensor,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@ -751,8 +751,11 @@ class VOLO(nn.Module):
             x = self.pos_drop(x)
             x = block(x)
             if idx in take_indices:
-                intermediates.append(x.permute(0, 3, 1, 2))
+                # normalize intermediates with final norm layer if enabled
+                if norm and idx >= 2:
+                    x_inter = self.norm(x)
+                else:
+                    x_inter = x
+                intermediates.append(x_inter.permute(0, 3, 1, 2))

         if intermediates_only:
             return intermediates
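A hedged reading of the idx >= 2 guard above: VOLO's network[0] (the outlooker stage) and network[1] (the downsample) run at a narrower channel width than self.norm expects, so only later indices can be normalized with the final norm layer:

    import torch
    import timm

    volo = timm.create_model('volo_d1_224', pretrained=False).eval()
    feats = volo.forward_intermediates(
        torch.randn(1, 3, 224, 224), norm=True, intermediates_only=True)
    print([tuple(f.shape) for f in feats])  # early-stage outputs stay un-normed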
@ -769,20 +772,20 @@ class VOLO(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
         max_index = self.stage_ends[max_index]
         self.network = self.network[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.post_network = nn.ModuleList()  # prune token blocks with head
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
@ -444,7 +444,7 @@ class Xcit(nn.Module):
             x: torch.Tensor,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:

@ -460,7 +460,7 @@ class Xcit(nn.Module):
         Returns:

         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@ -468,7 +468,6 @@ class Xcit(nn.Module):
         # forward pass
         B, _, height, width = x.shape
         x, (Hp, Wp) = self.patch_embed(x)
-
         if self.pos_embed is not None:
             # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
             pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
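A shape walk-through of the pos_embed line above (dimensions illustrative): the positional encoder yields (B, C, Hp, Wp); the reshape flattens the grid to (B, C, N) and the permute yields (B, N, C), matching the patch tokens:

    import torch

    B, C, Hp, Wp = 2, 128, 14, 14
    pos = torch.randn(B, C, Hp, Wp)          # as returned by self.pos_embed(B, Hp, Wp)
    pos_encoding = pos.reshape(B, C, Hp * Wp).permute(0, 2, 1)
    assert pos_encoding.shape == (B, Hp * Wp, C)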
@ -503,19 +502,19 @@ class Xcit(nn.Module):

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
         if prune_head:
             self.cls_attn_blocks = nn.ModuleList()  # prune token blocks with head
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):