mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Bug in last mod for features_only default_cfg
This commit is contained in:
parent
867a0e5a04
commit
cd72e66eff
@ -453,19 +453,19 @@ class EfficientNetFeatures(nn.Module):
|
||||
|
||||
|
||||
def _create_effnet(model_kwargs, variant, pretrained=False):
|
||||
features_only = False
|
||||
model_cls = EfficientNet
|
||||
if model_kwargs.pop('features_only', False):
|
||||
load_strict = False
|
||||
features_only = True
|
||||
model_kwargs.pop('num_classes', 0)
|
||||
model_kwargs.pop('num_features', 0)
|
||||
model_kwargs.pop('head_conv', None)
|
||||
model_cls = EfficientNetFeatures
|
||||
else:
|
||||
load_strict = True
|
||||
model_cls = EfficientNet
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
pretrained_strict=load_strict, **model_kwargs)
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
pretrained_strict=not features_only, **model_kwargs)
|
||||
if features_only:
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -773,16 +773,16 @@ class HighResolutionNetFeatures(HighResolutionNet):
|
||||
|
||||
def _create_hrnet(variant, pretrained, **model_kwargs):
|
||||
model_cls = HighResolutionNet
|
||||
strict = True
|
||||
features_only = False
|
||||
if model_kwargs.pop('features_only', False):
|
||||
model_cls = HighResolutionNetFeatures
|
||||
model_kwargs['num_classes'] = 0
|
||||
strict = False
|
||||
|
||||
features_only = True
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
|
||||
if features_only:
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -201,20 +201,20 @@ class MobileNetV3Features(nn.Module):
|
||||
|
||||
|
||||
def _create_mnv3(model_kwargs, variant, pretrained=False):
|
||||
features_only = False
|
||||
model_cls = MobileNetV3
|
||||
if model_kwargs.pop('features_only', False):
|
||||
load_strict = False
|
||||
features_only = True
|
||||
model_kwargs.pop('num_classes', 0)
|
||||
model_kwargs.pop('num_features', 0)
|
||||
model_kwargs.pop('head_conv', None)
|
||||
model_kwargs.pop('head_bias', None)
|
||||
model_cls = MobileNetV3Features
|
||||
else:
|
||||
load_strict = True
|
||||
model_cls = MobileNetV3
|
||||
model = build_model_with_cfg(
|
||||
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
|
||||
pretrained_strict=load_strict, **model_kwargs)
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
pretrained_strict=not features_only, **model_kwargs)
|
||||
if features_only:
|
||||
model.default_cfg = default_cfg_for_features(model.default_cfg)
|
||||
return model
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user