Wrap torch checkpoint() fn to default use_reentrant flag to False and allow env var override

Ross Wightman 2025-01-06 11:28:39 -08:00
parent 131518c15c
commit 2b251fb291
22 changed files with 91 additions and 54 deletions

View File

@@ -8,7 +8,8 @@ from .blur_pool import BlurPool2d, create_aa
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
+    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
+    set_reentrant_ckpt, use_reentrant_ckpt
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn

View File

@@ -8,7 +8,8 @@ import torch
 __all__ = [
     'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
-    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
+    'set_reentrant_ckpt', 'use_reentrant_ckpt'
 ]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
@@ -34,6 +35,12 @@ else:
     _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
 
+if 'TIMM_REENTRANT_CKPT' in os.environ:
+    _USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
+else:
+    _USE_REENTRANT_CKPT = False  # defaults to disabled (off)
+
 
 def is_no_jit():
     return _NO_JIT
@@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
         _USE_FUSED_ATTN = 1
     else:
         _USE_FUSED_ATTN = 0
+
+
+def use_reentrant_ckpt() -> bool:
+    return _USE_REENTRANT_CKPT
+
+
+def set_reentrant_ckpt(enable: bool = True):
+    global _USE_REENTRANT_CKPT
+    _USE_REENTRANT_CKPT = enable
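
The env override is read once at import time, and bool() of any non-empty string is True, so TIMM_REENTRANT_CKPT=0 also switches re-entrant mode on. A minimal usage sketch of the new config surface (illustrative, not part of the diff):

import os

# Set before timm.layers is first imported; any non-empty value counts as True.
os.environ['TIMM_REENTRANT_CKPT'] = '1'

from timm.layers import set_reentrant_ckpt, use_reentrant_ckpt

assert use_reentrant_ckpt()      # picked up from the env var at import
set_reentrant_ckpt(False)        # runtime override back to the new default
assert not use_reentrant_ckpt()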

View File

@@ -15,10 +15,9 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.layers import Format, _assert
+from ._manipulate import checkpoint
 
 __all__ = [
     'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',

View File

@@ -3,14 +3,17 @@ import math
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
 
 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
 
 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']
 
 
 def model_parameters(model: nn.Module, exclude_head: bool = False):
@@ -183,13 +186,35 @@ def flatten_modules(
             yield name, module
 
 
+def checkpoint(
+        function,
+        *args,
+        use_reentrant: Optional[bool] = None,
+        **kwargs,
+):
+    """ checkpoint wrapper fn
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint to default
+    use_reentrant to False
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten nn.Sequential of nn.Sequentials
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: Use re-entrant checkpointing
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@ def checkpoint_seq(
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
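
A short sketch of how the two helpers behave after this change (module and shapes arbitrary, for illustration only):

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint, checkpoint_seq

blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(6)])
x = torch.randn(2, 16, requires_grad=True)

# Single function: use_reentrant resolves via use_reentrant_ckpt(), False by default.
h = checkpoint(blocks[0], x)

# Sequential helper: checkpoint pairs of blocks, leave the last one uncheckpointed.
y = checkpoint_seq(blocks, x, every=2, skip_last=True)
y.sum().backward()  # segment activations are recomputed during backward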

View File

@@ -44,15 +44,14 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
 from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Beit']
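
Caller-facing behavior is unchanged by this import swap; gradient checkpointing is still toggled the same way, it just now routes through timm's checkpoint() wrapper. A usage sketch (untrained weights, shapes illustrative):

import timm
import torch

model = timm.create_model('beit_base_patch16_224', pretrained=False)
model.set_grad_checkpointing(True)  # blocks now run through the wrapper above
loss = model(torch.randn(1, 3, 224, 224)).sum()
loss.backward()  # activations recomputed, non-reentrant by default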

View File

@@ -8,13 +8,12 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
 from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
-from ._manipulate import MATCH_PREV_GROUP
+from ._manipulate import MATCH_PREV_GROUP, checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
 __all__ = ['DenseNet']
@@ -60,7 +59,7 @@ class DenseLayer(nn.Module):
         def closure(*xs):
             return self.bottleneck_fn(xs)
 
-        return cp.checkpoint(closure, *x)
+        return checkpoint(closure, *x)
 
     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
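
The closure exists because checkpoint() passes tensors positionally so autograd can track them, while DenseLayer's bottleneck_fn wants a list/tuple. A self-contained sketch of the same pattern (bottleneck fn and shapes made up):

import torch
from timm.models._manipulate import checkpoint

def bottleneck_fn(xs):  # takes a tuple of tensors, like DenseLayer
    return torch.cat(xs, 1)

xs = [torch.randn(2, 4, 8, 8, requires_grad=True) for _ in range(2)]
# Unpack tensors positionally for autograd; re-pack inside the closure.
out = checkpoint(lambda *tensors: bottleneck_fn(tensors), *xs)
out.sum().backward()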

View File

@@ -41,7 +41,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
@@ -51,7 +50,7 @@ from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['EfficientNet', 'EfficientNetFeatures']

View File

@@ -30,7 +30,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
@@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, Pa
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Eva']

View File

@@ -22,12 +22,11 @@ from typing import Callable, Optional, Tuple
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
 from ._builder import build_model_with_cfg
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['FocalNet']

View File

@@ -25,14 +25,13 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
     get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['GlobalContextVit']

View File

@@ -29,7 +29,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
@@ -39,7 +38,7 @@ from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 
 __all__ = ['Hiera']

View File

@@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
@@ -21,7 +20,7 @@ from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']

View File

@@ -20,7 +20,6 @@ from functools import partial, reduce
 from typing import Union, List, Tuple, Optional
 
 import torch
-import torch.utils.checkpoint as checkpoint
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -28,7 +27,8 @@ from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._registry import register_model, register_model_deprecations, generate_default_cfgs
+from ._manipulate import checkpoint
+from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['MultiScaleVit', 'MultiScaleVitCfg']  # model_registry will add each entrypoint fn to this

View File

@@ -21,11 +21,11 @@ from typing import Callable, List, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['PyramidVisionTransformerV2']

View File

@@ -18,14 +18,14 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\
     resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to this

View File

@@ -34,14 +34,13 @@ from typing import Tuple, Optional, List, Union, Any, Type
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn to this

View File

@@ -11,13 +11,13 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple
+from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model
-from .vision_transformer import resize_pos_embed
 
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
@@ -340,8 +340,11 @@ class TNT(nn.Module):
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if state_dict['patch_pos'].shape != model.patch_pos.shape:
-        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
-            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+        state_dict['patch_pos'] = resample_abs_pos_embed(
+            state_dict['patch_pos'],
+            new_size=model.pixel_embed.grid_size,
+            num_prefix_tokens=1,
+        )
     return state_dict
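
resample_abs_pos_embed interpolates a flattened (1, N, C) position table to a new grid while carrying prefix (class) tokens through unchanged, which is why the filter above only needs the target grid size. A quick shape check (dims arbitrary):

import torch
from timm.layers import resample_abs_pos_embed

pos = torch.randn(1, 1 + 14 * 14, 384)  # 1 cls token + 14x14 grid
pos = resample_abs_pos_embed(pos, new_size=(16, 16), num_prefix_tokens=1)
assert pos.shape == (1, 1 + 16 * 16, 384)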

View File

@@ -37,7 +37,6 @@ except ImportError:
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
@@ -1019,7 +1018,6 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     else:
         pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:
-        old_shape = pos_embed_w.shape
         num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
         pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
             pos_embed_w,

View File

@@ -17,13 +17,12 @@ except ImportError:
 import torch
 import torch.nn as nn
 from torch.jit import Final
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 from .vision_transformer import get_init_weights_vit

View File

@@ -16,7 +16,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
     Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
@@ -25,7 +24,7 @@ from torch.jit import Final
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 # model_registry will add each entrypoint fn to this

View File

@@ -26,12 +26,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['VOLO']  # model_registry will add each entrypoint fn to this

View File

@@ -17,16 +17,15 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn
+from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 from .cait import ClassAttn
-from .vision_transformer import Mlp
 
 __all__ = ['Xcit']  # model_registry will add each entrypoint fn to this