mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add a deprecation phase to module re-org
This commit is contained in:
parent
927f031293
commit
cda39b35bd
@ -19,7 +19,8 @@ import torch.nn as nn
|
|||||||
import torch.nn.parallel
|
import torch.nn.parallel
|
||||||
|
|
||||||
from timm.data import resolve_data_config
|
from timm.data import resolve_data_config
|
||||||
from timm.models import create_model, is_model, list_models, set_fast_norm
|
from timm.layers import set_fast_norm
|
||||||
|
from timm.models import create_model, is_model, list_models
|
||||||
from timm.optim import create_optimizer_v2
|
from timm.optim import create_optimizer_v2
|
||||||
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
|
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
|
||||||
|
|
||||||
|
@ -23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False
|
|||||||
_CHECK_HASH = False
|
_CHECK_HASH = False
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
|
||||||
|
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
|
||||||
|
|
||||||
|
|
||||||
def _resolve_pretrained_source(pretrained_cfg):
|
def _resolve_pretrained_source(pretrained_cfg):
|
||||||
cfg_source = pretrained_cfg.get('source', '')
|
cfg_source = pretrained_cfg.get('source', '')
|
||||||
pretrained_url = pretrained_cfg.get('url', None)
|
pretrained_url = pretrained_cfg.get('url', None)
|
||||||
|
@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf
|
|||||||
from ._registry import is_model, model_entrypoint
|
from ._registry import is_model, model_entrypoint
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
||||||
|
|
||||||
|
|
||||||
def parse_model_name(model_name):
|
def parse_model_name(model_name):
|
||||||
if model_name.startswith('hf_hub'):
|
if model_name.startswith('hf_hub'):
|
||||||
# NOTE for backwards compat, deprecate hf_hub use
|
# NOTE for backwards compat, deprecate hf_hub use
|
||||||
|
@ -17,6 +17,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||||
|
|
||||||
|
|
||||||
class FeatureInfo:
|
class FeatureInfo:
|
||||||
|
|
||||||
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
||||||
|
@ -35,6 +35,10 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
|
||||||
|
'FeatureGraphNet', 'GraphExtractNet']
|
||||||
|
|
||||||
|
|
||||||
def register_notrace_module(module: Type[nn.Module]):
|
def register_notrace_module(module: Type[nn.Module]):
|
||||||
"""
|
"""
|
||||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||||
|
@ -12,6 +12,8 @@ import timm.models._builder
|
|||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
|
||||||
|
|
||||||
|
|
||||||
def clean_state_dict(state_dict):
|
def clean_state_dict(state_dict):
|
||||||
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
||||||
|
@ -31,6 +31,9 @@ except ImportError:
|
|||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
|
||||||
|
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
|
||||||
|
|
||||||
|
|
||||||
def get_cache_dir(child_dir=''):
|
def get_cache_dir(child_dir=''):
|
||||||
"""
|
"""
|
||||||
|
@ -9,6 +9,9 @@ import torch
|
|||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
|
||||||
|
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
|
||||||
|
|
||||||
|
|
||||||
def model_parameters(model, exclude_head=False):
|
def model_parameters(model, exclude_head=False):
|
||||||
if exclude_head:
|
if exclude_head:
|
||||||
|
@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict
|
|||||||
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PretrainedCfg:
|
class PretrainedCfg:
|
||||||
"""
|
"""
|
||||||
|
@ -5,6 +5,8 @@ from torch import nn as nn
|
|||||||
|
|
||||||
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
|
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
|
||||||
|
|
||||||
|
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
|
||||||
|
|
||||||
|
|
||||||
def extract_layer(model, layer):
|
def extract_layer(model, layer):
|
||||||
layer = layer.split('.')
|
layer = layer.split('.')
|
||||||
|
@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple
|
|||||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||||
|
|
||||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||||
|
4
timm/models/factory.py
Normal file
4
timm/models/factory.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from ._factory import *
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
4
timm/models/features.py
Normal file
4
timm/models/features.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from ._features import *
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
4
timm/models/fx_features.py
Normal file
4
timm/models/fx_features.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from ._features_fx import *
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
7
timm/models/helpers.py
Normal file
7
timm/models/helpers.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from ._builder import *
|
||||||
|
from ._helpers import *
|
||||||
|
from ._manipulate import *
|
||||||
|
from ._prune import *
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
4
timm/models/hub.py
Normal file
4
timm/models/hub.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from _hub import *
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal
|
|||||||
from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
|
from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||||
from timm.layers.trace_utils import _assert, _float_to_int
|
from timm.layers.trace_utils import _assert, _float_to_int
|
||||||
from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
||||||
|
4
timm/models/registry.py
Normal file
4
timm/models/registry.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from ._registry import *
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
Loading…
x
Reference in New Issue
Block a user