diff --git a/README.md b/README.md
index afccc02d..7b1f06f0 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,9 @@
 
 ## What's New
 
+## Jan 6, 2025
+* Add `torch.utils.checkpoint.checkpoint()` wrapper in `timm.models` that defaults `use_reentrant=False`, unless `TIMM_REENTRANT_CKPT=1` is set in env.
+
 ## Dec 31, 2024
 * `convnext_nano` 384x384 ImageNet-12k pretrain & fine-tune. https://huggingface.co/models?search=convnext_nano%20r384
 * Add AIM-v2 encoders from https://github.com/apple/ml-aim, see on Hub: https://huggingface.co/models?search=timm%20aimv2
diff --git a/timm/layers/config.py b/timm/layers/config.py
index f69f3803..e2a23b5a 100644
--- a/timm/layers/config.py
+++ b/timm/layers/config.py
@@ -162,4 +162,4 @@ def use_reentrant_ckpt() -> bool:
 
 def set_reentrant_ckpt(enable: bool = True):
     global _USE_REENTRANT_CKPT
-    _USE_REENTRANT_CKPT = enable
\ No newline at end of file
+    _USE_REENTRANT_CKPT = enable
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index c5b1984f..3db5af60 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -91,7 +91,7 @@ from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extra
 from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
 from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
 from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
-    group_modules, group_parameters, checkpoint_seq, adapt_input_conv
+    group_modules, group_parameters, checkpoint_seq, checkpoint, adapt_input_conv
 from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
 from ._prune import adapt_model_from_string
 from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 3fb0b59f..3979b3b4 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -24,7 +24,7 @@ from torch.jit import Final
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 # model_registry will add each entrypoint fn to this
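
For context on the README entry above, here is a minimal sketch of how a `checkpoint()` wrapper with this behavior could be structured. It is an illustration only, not the actual timm implementation: it assumes the `use_reentrant_ckpt()` helper in `timm/layers/config.py` (touched in this diff) is backed by a `TIMM_REENTRANT_CKPT` environment check, and the exact env parsing shown here is a guess.

```python
# Hypothetical sketch of a checkpoint() wrapper that defaults
# use_reentrant=False unless TIMM_REENTRANT_CKPT=1 is set in the environment.
# Names and env handling are assumptions based on the diff, not timm's code.
import os
from typing import Optional

import torch
import torch.utils.checkpoint

# Assumed: global flag seeded from the environment at import time.
_USE_REENTRANT_CKPT = os.environ.get('TIMM_REENTRANT_CKPT', '0') == '1'


def use_reentrant_ckpt() -> bool:
    # False unless TIMM_REENTRANT_CKPT=1 was set in the environment.
    return _USE_REENTRANT_CKPT


def set_reentrant_ckpt(enable: bool = True):
    # Allow overriding the default at runtime.
    global _USE_REENTRANT_CKPT
    _USE_REENTRANT_CKPT = enable


def checkpoint(function, *args, use_reentrant: Optional[bool] = None, **kwargs):
    # Thin wrapper around torch.utils.checkpoint.checkpoint() that fills in
    # use_reentrant from the global/env setting when the caller doesn't pass it.
    if use_reentrant is None:
        use_reentrant = use_reentrant_ckpt()
    return torch.utils.checkpoint.checkpoint(
        function, *args, use_reentrant=use_reentrant, **kwargs)
```

A model forward could then call `checkpoint(block, x)` when gradient checkpointing is enabled and get the non-reentrant path by default, while `TIMM_REENTRANT_CKPT=1` (or `set_reentrant_ckpt(True)`) would switch it back to the reentrant implementation.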