mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #123 from aclex/mobilenetv3_fix_feature_extraction
Merge changes in feature extraction interface to MobileNetV3
This commit is contained in:
commit
e15f979457
@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
|
||||
and object detection models.
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
|
||||
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
||||
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
|
||||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._stage_to_feature_idx = {
|
||||
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
efficientnet_init_weights(self)
|
||||
if _DEBUG:
|
||||
for k, v in self.feature_info.items():
|
||||
for k, v in self._feature_info.items():
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
|
||||
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = [dict(
|
||||
name=self._feature_info[idx]['module'],
|
||||
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def feature_channels(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.feature_info[idx]['num_chs']
|
||||
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
return self._feature_info[idx]['num_chs']
|
||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
if self.feature_hooks is None:
|
||||
features = []
|
||||
for i, b in enumerate(self.blocks):
|
||||
x = b(x)
|
||||
if i in self._stage_to_feature_idx:
|
||||
features.append(x)
|
||||
return features
|
||||
else:
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
|
||||
|
||||
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user