1
0
mirror of https://github.com/huggingface/pytorch-image-models.git synced 2025-06-03 15:01:08 +08:00

Wrap torch checkpoint() fn to default use_reentrant flag to False and allow env var override

This commit is contained in:
Ross Wightman 2025-01-06 11:28:39 -08:00
parent 131518c15c
commit 2b251fb291
22 changed files with 91 additions and 54 deletions

@ -8,7 +8,8 @@ from .blur_pool import BlurPool2d, create_aa
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
set_reentrant_ckpt, use_reentrant_ckpt
from .conv2d_same import Conv2dSame, conv2d_same from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_act import create_act_layer, get_act_layer, get_act_fn

@ -8,7 +8,8 @@ import torch
__all__ = [ __all__ = [
'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn', 'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn' 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
'set_reentrant_ckpt', 'use_reentrant_ckpt'
] ]
# Set to True if prefer to have layers with no jit optimization (includes activations) # Set to True if prefer to have layers with no jit optimization (includes activations)
@ -34,6 +35,12 @@ else:
_USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
if 'TIMM_REENTRANT_CKPT' in os.environ:
_USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
else:
_USE_REENTRANT_CKPT = False # defaults to disabled (off)
def is_no_jit(): def is_no_jit():
return _NO_JIT return _NO_JIT
@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
_USE_FUSED_ATTN = 1 _USE_FUSED_ATTN = 1
else: else:
_USE_FUSED_ATTN = 0 _USE_FUSED_ATTN = 0
def use_reentrant_ckpt() -> bool:
return _USE_REENTRANT_CKPT
def set_reentrant_ckpt(enable: bool = True):
global _USE_REENTRANT_CKPT
_USE_REENTRANT_CKPT = enable

@ -15,10 +15,9 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.layers import Format, _assert from timm.layers import Format, _assert
from ._manipulate import checkpoint
__all__ = [ __all__ = [
'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet', 'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',

@ -3,14 +3,17 @@ import math
import re import re
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
import torch import torch
import torch.utils.checkpoint
from torch import nn as nn from torch import nn as nn
from torch.utils.checkpoint import checkpoint
from timm.layers import use_reentrant_ckpt
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv', __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq'] 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']
def model_parameters(model: nn.Module, exclude_head: bool = False): def model_parameters(model: nn.Module, exclude_head: bool = False):
@ -183,13 +186,35 @@ def flatten_modules(
yield name, module yield name, module
def checkpoint(
function,
*args,
use_reentrant: Optional[bool] = None,
**kwargs,
):
""" checkpoint wrapper fn
A thin wrapper around torch.utils.checkpoint.checkpoint to default
use_reentrant to False
"""
if use_reentrant is None:
use_reentrant = use_reentrant_ckpt()
return torch.utils.checkpoint.checkpoint(
function,
*args,
use_reentrant=use_reentrant,
**kwargs,
)
def checkpoint_seq( def checkpoint_seq(
functions, functions,
x, x,
every=1, every: int = 1,
flatten=False, flatten: bool = False,
skip_last=False, skip_last: bool = False,
preserve_rng_state=True use_reentrant: Optional[bool] = None,
): ):
r"""A helper function for checkpointing sequential models. r"""A helper function for checkpointing sequential models.
@ -215,10 +240,9 @@ def checkpoint_seq(
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially. functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
x: A Tensor that is input to :attr:`functions` x: A Tensor that is input to :attr:`functions`
every: checkpoint every-n functions (default: 1) every: checkpoint every-n functions (default: 1)
flatten (bool): flatten nn.Sequential of nn.Sequentials flatten: flatten nn.Sequential of nn.Sequentials
skip_last (bool): skip checkpointing the last function in the sequence if True skip_last: skip checkpointing the last function in the sequence if True
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring use_reentrant: Use re-entrant checkpointing
the RNG state during each checkpoint.
Returns: Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs` Output of running :attr:`functions` sequentially on :attr:`*inputs`
@ -227,6 +251,9 @@ def checkpoint_seq(
>>> model = nn.Sequential(...) >>> model = nn.Sequential(...)
>>> input_var = checkpoint_seq(model, input_var, every=2) >>> input_var = checkpoint_seq(model, input_var, every=2)
""" """
if use_reentrant is None:
use_reentrant = use_reentrant_ckpt()
def run_function(start, end, functions): def run_function(start, end, functions):
def forward(_x): def forward(_x):
for j in range(start, end + 1): for j in range(start, end + 1):
@ -247,7 +274,11 @@ def checkpoint_seq(
end = -1 end = -1
for start in range(0, num_checkpointed, every): for start in range(0, num_checkpointed, every):
end = min(start + every - 1, num_checkpointed - 1) end = min(start + every - 1, num_checkpointed - 1)
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state) x = torch.utils.checkpoint.checkpoint(
run_function(start, end, functions),
x,
use_reentrant=use_reentrant,
)
if skip_last: if skip_last:
return run_function(end + 1, len(functions) - 1, functions)(x) return run_function(end + 1, len(functions) - 1, functions)(x)
return x return x

@ -44,15 +44,14 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
__all__ = ['Beit'] __all__ = ['Beit']

@ -8,13 +8,12 @@ from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch.jit.annotations import List from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP from ._manipulate import MATCH_PREV_GROUP, checkpoint
from ._registry import register_model, generate_default_cfgs, register_model_deprecations from ._registry import register_model, generate_default_cfgs, register_model_deprecations
__all__ = ['DenseNet'] __all__ = ['DenseNet']
@ -60,7 +59,7 @@ class DenseLayer(nn.Module):
def closure(*xs): def closure(*xs):
return self.bottleneck_fn(xs) return self.bottleneck_fn(xs)
return cp.checkpoint(closure, *x) return checkpoint(closure, *x)
@torch.jit._overload_method # noqa: F811 @torch.jit._overload_method # noqa: F811
def forward(self, x): def forward(self, x):

@ -41,7 +41,6 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \ from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
@ -51,7 +50,7 @@ from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq, checkpoint
from ._registry import generate_default_cfgs, register_model, register_model_deprecations from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['EfficientNet', 'EfficientNetFeatures'] __all__ = ['EfficientNet', 'EfficientNetFeatures']

@ -30,7 +30,6 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
__all__ = ['Eva'] __all__ = ['Eva']

@ -22,12 +22,11 @@ from typing import Callable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import named_apply from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
__all__ = ['FocalNet'] __all__ = ['FocalNet']

@ -25,14 +25,13 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \ from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._manipulate import named_apply from ._manipulate import named_apply, checkpoint
from ._registry import register_model, generate_default_cfgs from ._registry import register_model, generate_default_cfgs
__all__ = ['GlobalContextVit'] __all__ = ['GlobalContextVit']

@ -29,7 +29,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \ from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
@ -39,7 +38,7 @@ from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._manipulate import named_apply from ._manipulate import named_apply, checkpoint
__all__ = ['Hiera'] __all__ = ['Hiera']

@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
@ -21,7 +20,7 @@ from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq, checkpoint
from ._registry import generate_default_cfgs, register_model, register_model_deprecations from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['MobileNetV3', 'MobileNetV3Features'] __all__ = ['MobileNetV3', 'MobileNetV3Features']

@ -20,7 +20,6 @@ from functools import partial, reduce
from typing import Union, List, Tuple, Optional from typing import Union, List, Tuple, Optional
import torch import torch
import torch.utils.checkpoint as checkpoint
from torch import nn from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -28,7 +27,8 @@ from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tup
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._registry import register_model, register_model_deprecations, generate_default_cfgs from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs
__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this __all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this

@ -21,11 +21,11 @@ from typing import Callable, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs from ._registry import register_model, generate_default_cfgs
__all__ = ['PyramidVisionTransformerV2'] __all__ = ['PyramidVisionTransformerV2']

@ -18,14 +18,14 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\ from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\
resample_patch_embed, ndgrid, get_act_layer, LayerType resample_patch_embed, ndgrid, get_act_layer, LayerType
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model, register_model_deprecations from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this __all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this

@ -34,14 +34,13 @@ from typing import Tuple, Optional, List, Union, Any, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._manipulate import named_apply from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this __all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this

@ -11,13 +11,13 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import checkpoint
from ._registry import register_model from ._registry import register_model
from .vision_transformer import resize_pos_embed
__all__ = ['TNT'] # model_registry will add each entrypoint fn to this __all__ = ['TNT'] # model_registry will add each entrypoint fn to this
@ -340,8 +340,11 @@ class TNT(nn.Module):
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
if state_dict['patch_pos'].shape != model.patch_pos.shape: if state_dict['patch_pos'].shape != model.patch_pos.shape:
state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'], state_dict['patch_pos'] = resample_abs_pos_embed(
model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size) state_dict['patch_pos'],
new_size=model.pixel_embed.grid_size,
num_prefix_tokens=1,
)
return state_dict return state_dict

@ -37,7 +37,6 @@ except ImportError:
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
@ -1019,7 +1018,6 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
else: else:
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape: if pos_embed_w.shape != model.pos_embed.shape:
old_shape = pos_embed_w.shape
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, pos_embed_w,

@ -17,13 +17,12 @@ except ImportError:
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.jit import Final from torch.jit import Final
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._manipulate import named_apply from ._manipulate import named_apply, checkpoint
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit from .vision_transformer import get_init_weights_vit

@ -16,7 +16,6 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \ from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
@ -25,7 +24,7 @@ from torch.jit import Final
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq, checkpoint
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
# model_registry will add each entrypoint fn to this # model_registry will add each entrypoint fn to this

@ -26,12 +26,12 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs from ._registry import register_model, generate_default_cfgs
__all__ = ['VOLO'] # model_registry will add each entrypoint fn to this __all__ = ['VOLO'] # model_registry will add each entrypoint fn to this

@ -17,16 +17,15 @@ from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features import feature_take_indices from ._features import feature_take_indices
from ._features_fx import register_notrace_module from ._features_fx import register_notrace_module
from ._manipulate import checkpoint
from ._registry import register_model, generate_default_cfgs, register_model_deprecations from ._registry import register_model, generate_default_cfgs, register_model_deprecations
from .cait import ClassAttn from .cait import ClassAttn
from .vision_transformer import Mlp
__all__ = ['Xcit'] # model_registry will add each entrypoint fn to this __all__ = ['Xcit'] # model_registry will add each entrypoint fn to this