Add default_cfg back to models wrapped in feature extraction module as per discussion in #294.
parent
4ca52d73d8
commit
867a0e5a04
|
@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
|||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||
from .features import FeatureInfo, FeatureHooks
|
||||
from .helpers import build_model_with_cfg
|
||||
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||
from .layers import create_conv2d, create_classifier
|
||||
from .registry import register_model
|
||||
|
||||
|
@ -462,9 +462,11 @@ def _create_effnet(model_kwargs, variant, pretrained=False):
|
|||
else:
|
||||
load_strict = True
|
||||
model_cls = EfficientNet
|
||||
return build_model_with_cfg(
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
pretrained_strict=load_strict, **model_kwargs)
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
|
|
|
@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
|
|||
return adapt_model_from_string(parent_module, f.read().strip())
|
||||
|
||||
|
||||
def default_cfg_for_features(default_cfg):
|
||||
default_cfg = deepcopy(default_cfg)
|
||||
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
||||
to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
|
||||
for tr in to_remove:
|
||||
default_cfg.pop(tr, None)
|
||||
return default_cfg
|
||||
|
||||
|
||||
def build_model_with_cfg(
|
||||
model_cls: Callable,
|
||||
variant: str,
|
||||
|
@ -296,5 +305,6 @@ def build_model_with_cfg(
|
|||
else:
|
||||
assert False, f'Unknown feature class {feature_cls}'
|
||||
model = feature_cls(model, **feature_cfg)
|
||||
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
|
||||
|
||||
return model
|
||||
|
|
|
@ -17,7 +17,7 @@ import torch.nn.functional as F
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .features import FeatureInfo
|
||||
from .helpers import build_model_with_cfg
|
||||
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||
from .layers import create_classifier
|
||||
from .registry import register_model
|
||||
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
|
||||
|
@ -779,9 +779,11 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
|
|||
model_kwargs['num_classes'] = 0
|
||||
strict = False
|
||||
|
||||
return build_model_with_cfg(
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -17,7 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
|||
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||
from .features import FeatureInfo, FeatureHooks
|
||||
from .helpers import build_model_with_cfg
|
||||
from .helpers import build_model_with_cfg, default_cfg_for_features
|
||||
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
|
||||
from .registry import register_model
|
||||
|
||||
|
@ -211,9 +211,11 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
|
|||
else:
|
||||
load_strict = True
|
||||
model_cls = MobileNetV3
|
||||
return build_model_with_cfg(
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
pretrained_strict=load_strict, **model_kwargs)
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue