Merge changes in feature extraction interface to MobileNetV3
Experimental feature extraction interface seems to be changed a little bit with the most up to date version apparently found in EfficientNet class. Here these changes are added to MobileNetV3 class to make it support it and work again, too.pull/123/head
parent
13cf68850b
commit
bdb165a8a4
|
@ -148,7 +148,7 @@ class MobileNetV3Features(nn.Module):
|
|||
and object detection models.
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
|
||||
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
|
||||
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
|
||||
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
|
@ -174,18 +174,23 @@ class MobileNetV3Features(nn.Module):
|
|||
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._stage_to_feature_idx = {
|
||||
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
efficientnet_init_weights(self)
|
||||
if _DEBUG:
|
||||
for k, v in self.feature_info.items():
|
||||
for k, v in self._feature_info.items():
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
|
||||
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = [dict(
|
||||
name=self._feature_info[idx]['module'],
|
||||
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def feature_channels(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
|
@ -193,15 +198,23 @@ class MobileNetV3Features(nn.Module):
|
|||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.feature_info[idx]['num_chs']
|
||||
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
return self._feature_info[idx]['num_chs']
|
||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
if self.feature_hooks is None:
|
||||
features = []
|
||||
for i, b in enumerate(self.blocks):
|
||||
x = b(x)
|
||||
if i in self._stage_to_feature_idx:
|
||||
features.append(x)
|
||||
return features
|
||||
else:
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
|
||||
|
||||
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
||||
|
|
Loading…
Reference in New Issue