mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix pruned adapt for EfficientNet models that are now using BatchNormAct layers
This commit is contained in:
parent
024fc4d9ab
commit
dc51334cdc
@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint
|
|||||||
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
||||||
from .fx_features import FeatureGraphNet
|
from .fx_features import FeatureGraphNet
|
||||||
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
|
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
|
||||||
from .layers import Conv2dSame, Linear
|
from .layers import Conv2dSame, Linear, BatchNormAct2d
|
||||||
from .registry import get_pretrained_cfg
|
from .registry import get_pretrained_cfg
|
||||||
|
|
||||||
|
|
||||||
@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
|
|||||||
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
||||||
groups=g, stride=old_module.stride)
|
groups=g, stride=old_module.stride)
|
||||||
set_layer(new_module, n, new_conv)
|
set_layer(new_module, n, new_conv)
|
||||||
if isinstance(old_module, nn.BatchNorm2d):
|
elif isinstance(old_module, BatchNormAct2d):
|
||||||
|
new_bn = BatchNormAct2d(
|
||||||
|
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
||||||
|
affine=old_module.affine, track_running_stats=True)
|
||||||
|
new_bn.drop = old_module.drop
|
||||||
|
new_bn.act = old_module.act
|
||||||
|
set_layer(new_module, n, new_bn)
|
||||||
|
elif isinstance(old_module, nn.BatchNorm2d):
|
||||||
new_bn = nn.BatchNorm2d(
|
new_bn = nn.BatchNorm2d(
|
||||||
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
||||||
affine=old_module.affine, track_running_stats=True)
|
affine=old_module.affine, track_running_stats=True)
|
||||||
set_layer(new_module, n, new_bn)
|
set_layer(new_module, n, new_bn)
|
||||||
if isinstance(old_module, nn.Linear):
|
elif isinstance(old_module, nn.Linear):
|
||||||
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
||||||
num_features = state_dict[n + '.weight'][1]
|
num_features = state_dict[n + '.weight'][1]
|
||||||
new_fc = Linear(
|
new_fc = Linear(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user