mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge changes in feature extraction interface to MobileNetV3
Experimental feature extraction interface seems to be changed a little bit with the most up to date version apparently found in EfficientNet class. Here these changes are added to MobileNetV3 class to make it support it and work again, too.
This commit is contained in:
parent
13cf68850b
commit
bdb165a8a4
@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
|
|||||||
and object detection models.
|
and object detection models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
|
||||||
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
||||||
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
|
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
|
||||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||||
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
|
|||||||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
||||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
self._feature_info = builder.features # builder provides info about feature channels for each block
|
||||||
|
self._stage_to_feature_idx = {
|
||||||
|
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
||||||
self._in_chs = builder.in_chs
|
self._in_chs = builder.in_chs
|
||||||
|
|
||||||
efficientnet_init_weights(self)
|
efficientnet_init_weights(self)
|
||||||
if _DEBUG:
|
if _DEBUG:
|
||||||
for k, v in self.feature_info.items():
|
for k, v in self._feature_info.items():
|
||||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||||
|
|
||||||
# Register feature extraction hooks with FeatureHooks helper
|
# Register feature extraction hooks with FeatureHooks helper
|
||||||
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
|
self.feature_hooks = None
|
||||||
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
|
if feature_location != 'bottleneck':
|
||||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
hooks = [dict(
|
||||||
|
name=self._feature_info[idx]['module'],
|
||||||
|
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
||||||
|
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||||
|
|
||||||
def feature_channels(self, idx=None):
|
def feature_channels(self, idx=None):
|
||||||
""" Feature Channel Shortcut
|
""" Feature Channel Shortcut
|
||||||
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
|
|||||||
return feature channel count for that feature block index (independent of out_indices setting).
|
return feature channel count for that feature block index (independent of out_indices setting).
|
||||||
"""
|
"""
|
||||||
if isinstance(idx, int):
|
if isinstance(idx, int):
|
||||||
return self.feature_info[idx]['num_chs']
|
return self._feature_info[idx]['num_chs']
|
||||||
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv_stem(x)
|
x = self.conv_stem(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
x = self.act1(x)
|
x = self.act1(x)
|
||||||
self.blocks(x)
|
if self.feature_hooks is None:
|
||||||
return self.feature_hooks.get_output(x.device)
|
features = []
|
||||||
|
for i, b in enumerate(self.blocks):
|
||||||
|
x = b(x)
|
||||||
|
if i in self._stage_to_feature_idx:
|
||||||
|
features.append(x)
|
||||||
|
return features
|
||||||
|
else:
|
||||||
|
self.blocks(x)
|
||||||
|
return self.feature_hooks.get_output(x.device)
|
||||||
|
|
||||||
|
|
||||||
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user