mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add default_cfg back to models wrapped in feature extraction module as per discussion in #294.
This commit is contained in:
parent
4ca52d73d8
commit
867a0e5a04
@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
|||||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||||
from .features import FeatureInfo, FeatureHooks
|
from .features import FeatureInfo, FeatureHooks
|
||||||
from .helpers import build_model_with_cfg
|
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||||
from .layers import create_conv2d, create_classifier
|
from .layers import create_conv2d, create_classifier
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
|
|
||||||
@ -462,9 +462,11 @@ def _create_effnet(model_kwargs, variant, pretrained=False):
|
|||||||
else:
|
else:
|
||||||
load_strict = True
|
load_strict = True
|
||||||
model_cls = EfficientNet
|
model_cls = EfficientNet
|
||||||
return build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||||
pretrained_strict=load_strict, **model_kwargs)
|
pretrained_strict=load_strict, **model_kwargs)
|
||||||
|
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||||
|
@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
|
|||||||
return adapt_model_from_string(parent_module, f.read().strip())
|
return adapt_model_from_string(parent_module, f.read().strip())
|
||||||
|
|
||||||
|
|
||||||
|
def default_cfg_for_features(default_cfg):
|
||||||
|
default_cfg = deepcopy(default_cfg)
|
||||||
|
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
||||||
|
to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
|
||||||
|
for tr in to_remove:
|
||||||
|
default_cfg.pop(tr, None)
|
||||||
|
return default_cfg
|
||||||
|
|
||||||
|
|
||||||
def build_model_with_cfg(
|
def build_model_with_cfg(
|
||||||
model_cls: Callable,
|
model_cls: Callable,
|
||||||
variant: str,
|
variant: str,
|
||||||
@ -296,5 +305,6 @@ def build_model_with_cfg(
|
|||||||
else:
|
else:
|
||||||
assert False, f'Unknown feature class {feature_cls}'
|
assert False, f'Unknown feature class {feature_cls}'
|
||||||
model = feature_cls(model, **feature_cfg)
|
model = feature_cls(model, **feature_cfg)
|
||||||
|
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -17,7 +17,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from .features import FeatureInfo
|
from .features import FeatureInfo
|
||||||
from .helpers import build_model_with_cfg
|
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||||
from .layers import create_classifier
|
from .layers import create_classifier
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
|
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
|
||||||
@ -779,9 +779,11 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
|
|||||||
model_kwargs['num_classes'] = 0
|
model_kwargs['num_classes'] = 0
|
||||||
strict = False
|
strict = False
|
||||||
|
|
||||||
return build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||||
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
|
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
|
||||||
|
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
@ -17,7 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
|||||||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||||
from .features import FeatureInfo, FeatureHooks
|
from .features import FeatureInfo, FeatureHooks
|
||||||
from .helpers import build_model_with_cfg
|
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||||
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
|
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
|
|
||||||
@ -211,9 +211,11 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
|
|||||||
else:
|
else:
|
||||||
load_strict = True
|
load_strict = True
|
||||||
model_cls = MobileNetV3
|
model_cls = MobileNetV3
|
||||||
return build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||||
pretrained_strict=load_strict, **model_kwargs)
|
pretrained_strict=load_strict, **model_kwargs)
|
||||||
|
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user