pytorch-image-models (mirror of https://github.com/huggingface/pytorch-image-models.git)

commit 2b251fb291
parent 131518c15c

Wrap torch checkpoint() fn to default use_reentrant flag to False and allow env var override
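For context, a minimal sketch of how the new default and the overrides added in this commit are meant to be used; the model name below is illustrative, everything else follows from the diff:

    import os

    # Opt back into re-entrant checkpointing via env var; it is read once,
    # when timm.layers.config is first imported.
    os.environ['TIMM_REENTRANT_CKPT'] = '1'

    import timm
    from timm.layers import set_reentrant_ckpt, use_reentrant_ckpt

    # Or flip the global at runtime instead of via the environment:
    set_reentrant_ckpt(False)
    assert use_reentrant_ckpt() is False

    # Models with gradient checkpointing enabled now route through timm's
    # wrapper, which fills in use_reentrant from this global default.
    model = timm.create_model('vit_base_patch16_224')
    model.set_grad_checkpointing()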
timm/layers/__init__.py
@@ -8,7 +8,8 @@ from .blur_pool import BlurPool2d, create_aa
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
+    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
+    set_reentrant_ckpt, use_reentrant_ckpt
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
timm/layers/config.py
@@ -8,7 +8,8 @@ import torch
 
 __all__ = [
     'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
-    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn',
+    'set_reentrant_ckpt', 'use_reentrant_ckpt'
 ]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
@@ -34,6 +35,12 @@ else:
     _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
 
 
+if 'TIMM_REENTRANT_CKPT' in os.environ:
+    _USE_REENTRANT_CKPT = bool(os.environ['TIMM_REENTRANT_CKPT'])
+else:
+    _USE_REENTRANT_CKPT = False  # defaults to disabled (off)
+
+
 def is_no_jit():
     return _NO_JIT
 
@@ -147,3 +154,12 @@ def set_fused_attn(enable: bool = True, experimental: bool = False):
         _USE_FUSED_ATTN = 1
     else:
         _USE_FUSED_ATTN = 0
+
+
+def use_reentrant_ckpt() -> bool:
+    return _USE_REENTRANT_CKPT
+
+
+def set_reentrant_ckpt(enable: bool = True):
+    global _USE_REENTRANT_CKPT
+    _USE_REENTRANT_CKPT = enable
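One caveat worth noting about the env override as written: bool() on an environment string is truthy for any non-empty value, so TIMM_REENTRANT_CKPT=0 still enables re-entrant checkpointing; only an unset or empty variable yields False:

    >>> import os
    >>> os.environ['TIMM_REENTRANT_CKPT'] = '0'
    >>> bool(os.environ['TIMM_REENTRANT_CKPT'])  # any non-empty string is truthy
    True
    >>> bool('')
    False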
timm/models/_features.py
@@ -15,10 +15,9 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
-
 from timm.layers import Format, _assert
 
+from ._manipulate import checkpoint
 
 __all__ = [
     'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
timm/models/_manipulate.py
@@ -3,14 +3,17 @@ import math
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union
 
 import torch
+import torch.utils.checkpoint
 from torch import nn as nn
-from torch.utils.checkpoint import checkpoint
+
+from timm.layers import use_reentrant_ckpt
+
 
 __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
-           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
+           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint']
 
 
 def model_parameters(model: nn.Module, exclude_head: bool = False):
@@ -183,13 +186,35 @@ def flatten_modules(
         yield name, module
 
 
+def checkpoint(
+        function,
+        *args,
+        use_reentrant: Optional[bool] = None,
+        **kwargs,
+):
+    """ checkpoint wrapper fn
+
+    A thin wrapper around torch.utils.checkpoint.checkpoint to default
+    use_reentrant to False
+    """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
+    return torch.utils.checkpoint.checkpoint(
+        function,
+        *args,
+        use_reentrant=use_reentrant,
+        **kwargs,
+    )
+
+
 def checkpoint_seq(
         functions,
         x,
-        every=1,
-        flatten=False,
-        skip_last=False,
-        preserve_rng_state=True
+        every: int = 1,
+        flatten: bool = False,
+        skip_last: bool = False,
+        use_reentrant: Optional[bool] = None,
 ):
     r"""A helper function for checkpointing sequential models.
 
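A quick sketch of calling the new wrapper directly; with the defaults above it behaves like torch.utils.checkpoint.checkpoint(..., use_reentrant=False), which, unlike the re-entrant variant, also supports keyword arguments for the wrapped function:

    import torch
    import torch.nn as nn
    from timm.models._manipulate import checkpoint

    layer = nn.Linear(16, 16)
    x = torch.randn(2, 16, requires_grad=True)

    # use_reentrant=None defers to use_reentrant_ckpt(), i.e. False unless overridden.
    y = checkpoint(layer, x)
    y.sum().backward()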
@@ -215,10 +240,9 @@ def checkpoint_seq(
         functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
         x: A Tensor that is input to :attr:`functions`
         every: checkpoint every-n functions (default: 1)
-        flatten (bool): flatten nn.Sequential of nn.Sequentials
-        skip_last (bool): skip checkpointing the last function in the sequence if True
-        preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
-            the RNG state during each checkpoint.
+        flatten: flatten nn.Sequential of nn.Sequentials
+        skip_last: skip checkpointing the last function in the sequence if True
+        use_reentrant: Use re-entrant checkpointing
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -227,6 +251,9 @@ def checkpoint_seq(
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_seq(model, input_var, every=2)
     """
+    if use_reentrant is None:
+        use_reentrant = use_reentrant_ckpt()
+
     def run_function(start, end, functions):
         def forward(_x):
             for j in range(start, end + 1):
@@ -247,7 +274,11 @@ def checkpoint_seq(
     end = -1
     for start in range(0, num_checkpointed, every):
         end = min(start + every - 1, num_checkpointed - 1)
-        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
+        x = torch.utils.checkpoint.checkpoint(
+            run_function(start, end, functions),
+            x,
+            use_reentrant=use_reentrant,
+        )
     if skip_last:
         return run_function(end + 1, len(functions) - 1, functions)(x)
     return x
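And a sketch of the updated checkpoint_seq call (module and shapes are illustrative):

    import torch
    import torch.nn as nn
    from timm.models._manipulate import checkpoint_seq

    blocks = nn.Sequential(*[nn.Linear(32, 32) for _ in range(8)])
    x = torch.randn(4, 32, requires_grad=True)

    # Checkpoint every 2 modules; use_reentrant=None again defers to the global default.
    y = checkpoint_seq(blocks, x, every=2)
    y.sum().backward()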
timm/models/beit.py
@@ -44,15 +44,14 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
 from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
 
-
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Beit']
timm/models/densenet.py
@@ -8,13 +8,12 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as cp
 from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
-from ._manipulate import MATCH_PREV_GROUP
+from ._manipulate import MATCH_PREV_GROUP, checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 
 __all__ = ['DenseNet']
@@ -60,7 +59,7 @@ class DenseLayer(nn.Module):
         def closure(*xs):
             return self.bottleneck_fn(xs)
 
-        return cp.checkpoint(closure, *x)
+        return checkpoint(closure, *x)
 
     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
timm/models/efficientnet.py
@@ -41,7 +41,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
@@ -51,7 +50,7 @@ from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['EfficientNet', 'EfficientNetFeatures']
timm/models/eva.py
@@ -30,7 +30,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
@@ -39,6 +38,7 @@ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['Eva']
timm/models/focalnet.py
@@ -22,12 +22,11 @@ from typing import Callable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
 from ._builder import build_model_with_cfg
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['FocalNet']
timm/models/gcvit.py
@@ -25,14 +25,13 @@ from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
     get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['GlobalContextVit']
timm/models/hiera.py
@@ -29,7 +29,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, LayerScale, ClNormMlpClassifierHead, use_fused_attn, \
@@ -39,7 +38,7 @@ from ._registry import generate_default_cfgs, register_model
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 
 
 __all__ = ['Hiera']
timm/models/mobilenetv3.py
@@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
@@ -21,7 +20,7 @@ from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks, feature_take_indices
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']
timm/models/mvitv2.py
@@ -20,7 +20,6 @@ from functools import partial, reduce
 from typing import Union, List, Tuple, Optional
 
 import torch
-import torch.utils.checkpoint as checkpoint
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -28,7 +27,8 @@ from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._registry import register_model, register_model_deprecations, generate_default_cfgs
+from ._manipulate import checkpoint
+from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['MultiScaleVit', 'MultiScaleVitCfg']  # model_registry will add each entrypoint fn to this
 
timm/models/pvt_v2.py
@@ -21,11 +21,11 @@ from typing import Callable, List, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['PyramidVisionTransformerV2']
timm/models/swin_transformer_v2.py
@@ -18,14 +18,14 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, ClassifierHead,\
     resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to this
timm/models/swin_transformer_v2_cr.py
@@ -34,14 +34,13 @@ from typing import Tuple, Optional, List, Union, Any, Type
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn to this
timm/models/tnt.py
@@ -11,13 +11,13 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple
+from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint
 from ._registry import register_model
 from .vision_transformer import resize_pos_embed
 
 
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
@@ -340,8 +340,11 @@ class TNT(nn.Module):
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if state_dict['patch_pos'].shape != model.patch_pos.shape:
-        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
-            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+        state_dict['patch_pos'] = resample_abs_pos_embed(
+            state_dict['patch_pos'],
+            new_size=model.pixel_embed.grid_size,
+            num_prefix_tokens=1,
+        )
     return state_dict
 
 
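The TNT filter above now resizes absolute position embeddings with resample_abs_pos_embed, which this diff imports from timm.layers; a standalone sketch with made-up shapes:

    import torch
    from timm.layers import resample_abs_pos_embed

    # (1, 1 + 14*14, 384): one prefix token plus a 14x14 grid of patch embeddings.
    pos = torch.randn(1, 1 + 14 * 14, 384)
    pos_new = resample_abs_pos_embed(pos, new_size=(16, 16), num_prefix_tokens=1)
    print(pos_new.shape)  # torch.Size([1, 257, 384])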
timm/models/vision_transformer.py
@@ -37,7 +37,6 @@ except ImportError:
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
@@ -1019,7 +1018,6 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     else:
         pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:
-        old_shape = pos_embed_w.shape
         num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
         pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
             pos_embed_w,
timm/models/vision_transformer_relpos.py
@@ -17,13 +17,12 @@ except ImportError:
 import torch
 import torch.nn as nn
 from torch.jit import Final
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import named_apply
+from ._manipulate import named_apply, checkpoint
 from ._registry import generate_default_cfgs, register_model
 from .vision_transformer import get_init_weights_vit
 
timm/models/vision_transformer_sam.py
@@ -16,7 +16,6 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, PatchDropout, LayerNorm2d, ClassifierHead, NormMlpClassifierHead, \
     Format, resample_abs_pos_embed_nhwc, RotaryEmbeddingCat, apply_rot_embed_cat, to_2tuple, use_fused_attn
@@ -25,7 +24,7 @@ from torch.jit import Final
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq
+from ._manipulate import checkpoint_seq, checkpoint
 from ._registry import generate_default_cfgs, register_model
 
 # model_registry will add each entrypoint fn to this
timm/models/volo.py
@@ -26,12 +26,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs
 
 __all__ = ['VOLO']  # model_registry will add each entrypoint fn to this
timm/models/xcit.py
@@ -17,16 +17,15 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn
+from timm.layers import DropPath, trunc_normal_, to_2tuple, use_fused_attn, Mlp
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
+from ._manipulate import checkpoint
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
 from .cait import ClassAttn
-from .vision_transformer import Mlp
 
 __all__ = ['Xcit']  # model_registry will add each entrypoint fn to this
 