mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add a deprecation phase to module re-org
This commit is contained in:
parent
927f031293
commit
cda39b35bd
@ -19,7 +19,8 @@ import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
|
||||
from timm.data import resolve_data_config
|
||||
from timm.models import create_model, is_model, list_models, set_fast_norm
|
||||
from timm.layers import set_fast_norm
|
||||
from timm.models import create_model, is_model, list_models
|
||||
from timm.optim import create_optimizer_v2
|
||||
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
|
||||
|
||||
|
@ -23,6 +23,10 @@ _DOWNLOAD_PROGRESS = False
|
||||
_CHECK_HASH = False
|
||||
|
||||
|
||||
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
|
||||
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
|
||||
|
||||
|
||||
def _resolve_pretrained_source(pretrained_cfg):
|
||||
cfg_source = pretrained_cfg.get('source', '')
|
||||
pretrained_url = pretrained_cfg.get('url', None)
|
||||
|
@ -9,6 +9,9 @@ from ._hub import load_model_config_from_hf
|
||||
from ._registry import is_model, model_entrypoint
|
||||
|
||||
|
||||
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
||||
|
||||
|
||||
def parse_model_name(model_name):
|
||||
if model_name.startswith('hf_hub'):
|
||||
# NOTE for backwards compat, deprecate hf_hub use
|
||||
|
@ -17,6 +17,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||
|
||||
|
||||
class FeatureInfo:
|
||||
|
||||
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
||||
|
@ -35,6 +35,10 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
|
||||
'FeatureGraphNet', 'GraphExtractNet']
|
||||
|
||||
|
||||
def register_notrace_module(module: Type[nn.Module]):
|
||||
"""
|
||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||
|
@ -12,6 +12,8 @@ import timm.models._builder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
|
||||
|
||||
|
||||
def clean_state_dict(state_dict):
|
||||
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
||||
|
@ -31,6 +31,9 @@ except ImportError:
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
|
||||
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
|
||||
|
||||
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
|
@ -9,6 +9,9 @@ import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
|
||||
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
|
||||
|
||||
|
||||
def model_parameters(model, exclude_head=False):
|
||||
if exclude_head:
|
||||
|
@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict
|
||||
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
||||
|
||||
|
||||
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretrainedCfg:
|
||||
"""
|
||||
|
@ -5,6 +5,8 @@ from torch import nn as nn
|
||||
|
||||
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
|
||||
|
||||
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
|
||||
|
||||
|
||||
def extract_layer(model, layer):
|
||||
layer = layer.split('.')
|
||||
|
@ -12,7 +12,7 @@ from typing import List, Optional, Union, Tuple
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
|
||||
__all__ = [
|
||||
'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
|
||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||
|
4
timm/models/factory.py
Normal file
4
timm/models/factory.py
Normal file
@ -0,0 +1,4 @@
|
||||
from ._factory import *
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
4
timm/models/features.py
Normal file
4
timm/models/features.py
Normal file
@ -0,0 +1,4 @@
|
||||
from ._features import *
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
4
timm/models/fx_features.py
Normal file
4
timm/models/fx_features.py
Normal file
@ -0,0 +1,4 @@
|
||||
from ._features_fx import *
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
7
timm/models/helpers.py
Normal file
7
timm/models/helpers.py
Normal file
@ -0,0 +1,7 @@
|
||||
from ._builder import *
|
||||
from ._helpers import *
|
||||
from ._manipulate import *
|
||||
from ._prune import *
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
4
timm/models/hub.py
Normal file
4
timm/models/hub.py
Normal file
@ -0,0 +1,4 @@
|
||||
from _hub import *
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
@ -43,3 +43,6 @@ from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, Scal
|
||||
from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
from timm.layers.trace_utils import _assert, _float_to_int
|
||||
from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
||||
|
4
timm/models/registry.py
Normal file
4
timm/models/registry.py
Normal file
@ -0,0 +1,4 @@
|
||||
from ._registry import *
|
||||
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
Loading…
x
Reference in New Issue
Block a user