Merge pull request #1850 from huggingface/effnet_improve_features_only
Support other features only modes for EfficientNet. Fix #1848, fix #1849 (pull/1866/head)
commit
c241081251
|
@ -75,7 +75,7 @@ class ConvBnAct(nn.Module):
|
||||||
if location == 'expansion': # output of conv after act, same as block coutput
|
if location == 'expansion': # output of conv after act, same as block coutput
|
||||||
return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
|
return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
|
||||||
else: # location == 'bottleneck', block output
|
else: # location == 'bottleneck', block output
|
||||||
return dict(module='', hook_type='', num_chs=self.conv.out_channels)
|
return dict(module='', num_chs=self.conv.out_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shortcut = x
|
shortcut = x
|
||||||
|
@ -116,7 +116,7 @@ class DepthwiseSeparableConv(nn.Module):
|
||||||
if location == 'expansion': # after SE, input to PW
|
if location == 'expansion': # after SE, input to PW
|
||||||
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
|
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
|
||||||
else: # location == 'bottleneck', block output
|
else: # location == 'bottleneck', block output
|
||||||
return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
|
return dict(module='', num_chs=self.conv_pw.out_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shortcut = x
|
shortcut = x
|
||||||
|
@ -173,7 +173,7 @@ class InvertedResidual(nn.Module):
|
||||||
if location == 'expansion': # after SE, input to PWL
|
if location == 'expansion': # after SE, input to PWL
|
||||||
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
|
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
|
||||||
else: # location == 'bottleneck', block output
|
else: # location == 'bottleneck', block output
|
||||||
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
|
return dict(module='', num_chs=self.conv_pwl.out_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shortcut = x
|
shortcut = x
|
||||||
|
@ -266,7 +266,7 @@ class EdgeResidual(nn.Module):
|
||||||
if location == 'expansion': # after SE, before PWL
|
if location == 'expansion': # after SE, before PWL
|
||||||
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
|
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
|
||||||
else: # location == 'bottleneck', block output
|
else: # location == 'bottleneck', block output
|
||||||
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
|
return dict(module='', num_chs=self.conv_pwl.out_channels)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shortcut = x
|
shortcut = x
|
||||||
|
|
|
@ -370,9 +370,7 @@ class EfficientNetBuilder:
|
||||||
stages = []
|
stages = []
|
||||||
if model_block_args[0][0]['stride'] > 1:
|
if model_block_args[0][0]['stride'] > 1:
|
||||||
# if the first block starts with a stride, we need to extract first level feat from stem
|
# if the first block starts with a stride, we need to extract first level feat from stem
|
||||||
feature_info = dict(
|
feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
|
||||||
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
|
|
||||||
hook_type='forward' if self.feature_location != 'bottleneck' else '')
|
|
||||||
self.features.append(feature_info)
|
self.features.append(feature_info)
|
||||||
|
|
||||||
# outer list of block_args defines the stacks
|
# outer list of block_args defines the stacks
|
||||||
|
@ -418,10 +416,16 @@ class EfficientNetBuilder:
|
||||||
# stash feature module name and channel info for model feature extraction
|
# stash feature module name and channel info for model feature extraction
|
||||||
if extract_features:
|
if extract_features:
|
||||||
feature_info = dict(
|
feature_info = dict(
|
||||||
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
|
stage=stack_idx + 1,
|
||||||
module_name = f'blocks.{stack_idx}.{block_idx}'
|
reduction=current_stride,
|
||||||
|
**block.feature_info(self.feature_location),
|
||||||
|
)
|
||||||
leaf_name = feature_info.get('module', '')
|
leaf_name = feature_info.get('module', '')
|
||||||
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
|
if leaf_name:
|
||||||
|
feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
|
||||||
|
else:
|
||||||
|
assert last_block
|
||||||
|
feature_info['module'] = f'blocks.{stack_idx}'
|
||||||
self.features.append(feature_info)
|
self.features.append(feature_info)
|
||||||
|
|
||||||
total_block_idx += 1 # incr global block idx (across all stacks)
|
total_block_idx += 1 # incr global block idx (across all stacks)
|
||||||
|
|
|
@ -27,12 +27,13 @@ class FeatureInfo:
|
||||||
|
|
||||||
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
||||||
prev_reduction = 1
|
prev_reduction = 1
|
||||||
for fi in feature_info:
|
for i, fi in enumerate(feature_info):
|
||||||
# sanity check the mandatory fields, there may be additional fields depending on the model
|
# sanity check the mandatory fields, there may be additional fields depending on the model
|
||||||
assert 'num_chs' in fi and fi['num_chs'] > 0
|
assert 'num_chs' in fi and fi['num_chs'] > 0
|
||||||
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
||||||
prev_reduction = fi['reduction']
|
prev_reduction = fi['reduction']
|
||||||
assert 'module' in fi
|
assert 'module' in fi
|
||||||
|
fi.setdefault('index', i)
|
||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
self.info = feature_info
|
self.info = feature_info
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Callable, List, Dict, Union, Type
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ._features import _get_feature_info
|
from ._features import _get_feature_info, _get_return_layers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
|
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
|
||||||
|
@ -93,9 +93,7 @@ class FeatureGraphNet(nn.Module):
|
||||||
self.feature_info = _get_feature_info(model, out_indices)
|
self.feature_info = _get_feature_info(model, out_indices)
|
||||||
if out_map is not None:
|
if out_map is not None:
|
||||||
assert len(out_map) == len(out_indices)
|
assert len(out_map) == len(out_indices)
|
||||||
return_nodes = {
|
return_nodes = _get_return_layers(self.feature_info, out_map)
|
||||||
info['module']: out_map[i] if out_map is not None else info['module']
|
|
||||||
for i, info in enumerate(self.feature_info) if i in out_indices}
|
|
||||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -232,7 +232,7 @@ class EfficientNetFeatures(nn.Module):
|
||||||
)
|
)
|
||||||
self.blocks = nn.Sequential(*builder(stem_size, block_args))
|
self.blocks = nn.Sequential(*builder(stem_size, block_args))
|
||||||
self.feature_info = FeatureInfo(builder.features, out_indices)
|
self.feature_info = FeatureInfo(builder.features, out_indices)
|
||||||
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
|
self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}
|
||||||
|
|
||||||
efficientnet_init_weights(self)
|
efficientnet_init_weights(self)
|
||||||
|
|
||||||
|
@ -268,20 +268,28 @@ class EfficientNetFeatures(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def _create_effnet(variant, pretrained=False, **kwargs):
|
def _create_effnet(variant, pretrained=False, **kwargs):
|
||||||
features_only = False
|
features_mode = ''
|
||||||
model_cls = EfficientNet
|
model_cls = EfficientNet
|
||||||
kwargs_filter = None
|
kwargs_filter = None
|
||||||
if kwargs.pop('features_only', False):
|
if kwargs.pop('features_only', False):
|
||||||
features_only = True
|
if 'feature_cfg' in kwargs:
|
||||||
|
features_mode = 'cfg'
|
||||||
|
else:
|
||||||
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
|
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
|
||||||
model_cls = EfficientNetFeatures
|
model_cls = EfficientNetFeatures
|
||||||
|
features_mode = 'cls'
|
||||||
|
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls, variant, pretrained,
|
model_cls,
|
||||||
pretrained_strict=not features_only,
|
variant,
|
||||||
|
pretrained,
|
||||||
|
features_only=features_mode == 'cfg',
|
||||||
|
pretrained_strict=features_mode != 'cls',
|
||||||
kwargs_filter=kwargs_filter,
|
kwargs_filter=kwargs_filter,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
if features_only:
|
)
|
||||||
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
|
if features_mode == 'cls':
|
||||||
|
model.pretrained_cfg = model.default_cfg = pretrained_cfg_for_features(model.pretrained_cfg)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -829,7 +829,7 @@ class HighResolutionNetFeatures(HighResolutionNet):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.feature_info = FeatureInfo(self.feature_info, out_indices)
|
self.feature_info = FeatureInfo(self.feature_info, out_indices)
|
||||||
self._out_idx = {i for i in out_indices}
|
self._out_idx = {f['index'] for f in self.feature_info.get_dicts()}
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
assert False, 'Not supported'
|
assert False, 'Not supported'
|
||||||
|
|
|
@ -210,7 +210,7 @@ class MobileNetV3Features(nn.Module):
|
||||||
)
|
)
|
||||||
self.blocks = nn.Sequential(*builder(stem_size, block_args))
|
self.blocks = nn.Sequential(*builder(stem_size, block_args))
|
||||||
self.feature_info = FeatureInfo(builder.features, out_indices)
|
self.feature_info = FeatureInfo(builder.features, out_indices)
|
||||||
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
|
self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}
|
||||||
|
|
||||||
efficientnet_init_weights(self)
|
efficientnet_init_weights(self)
|
||||||
|
|
||||||
|
@ -247,21 +247,27 @@ class MobileNetV3Features(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def _create_mnv3(variant, pretrained=False, **kwargs):
|
def _create_mnv3(variant, pretrained=False, **kwargs):
|
||||||
features_only = False
|
features_mode = ''
|
||||||
model_cls = MobileNetV3
|
model_cls = MobileNetV3
|
||||||
kwargs_filter = None
|
kwargs_filter = None
|
||||||
if kwargs.pop('features_only', False):
|
if kwargs.pop('features_only', False):
|
||||||
features_only = True
|
if 'feature_cfg' in kwargs:
|
||||||
|
features_mode = 'cfg'
|
||||||
|
else:
|
||||||
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
|
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
|
||||||
model_cls = MobileNetV3Features
|
model_cls = MobileNetV3Features
|
||||||
|
features_mode = 'cls'
|
||||||
|
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls,
|
model_cls,
|
||||||
variant,
|
variant,
|
||||||
pretrained,
|
pretrained,
|
||||||
pretrained_strict=not features_only,
|
features_only=features_mode == 'cfg',
|
||||||
|
pretrained_strict=features_mode != 'cls',
|
||||||
kwargs_filter=kwargs_filter,
|
kwargs_filter=kwargs_filter,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
if features_only:
|
)
|
||||||
|
if features_mode == 'cls':
|
||||||
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
|
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue