Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Cleanup EfficientNet/MobileNetV3 feature extraction a bit, only two tap locations now, small mobilenetv3 models work
parent 68fd8a267b
commit c146b54abc
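
The hunks below touch EfficientNetFeatures, the EfficientNet block classes, EfficientNetBuilder, and MobileNetV3Features. The cleanup reduces the feature tap locations to two: 'bottleneck' (block output) and 'expansion' (input to the pointwise-linear conv, captured with a forward_pre hook). A hedged usage sketch of the resulting feature-extraction interface; it assumes the features_only path as it behaves in later timm releases and is not part of this commit, and the model name is only illustrative:

import torch
import timm

# Build a small MobileNetV3 as a feature backbone; out_indices (default 0..4) selects
# which tapped stages (stem counts as stage 0) are returned by forward().
model = timm.create_model('mobilenetv3_small_100', features_only=True, pretrained=False)
print(model.feature_info.channels())    # per-stage channel counts
print(model.feature_info.reduction())   # per-stage output strides
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])
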
@@ -416,12 +416,6 @@ class EfficientNetFeatures(nn.Module):
                  se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(EfficientNetFeatures, self).__init__()
         norm_kwargs = norm_kwargs or {}
-
-        # TODO only create stages needed, currently all stages are created regardless of out_indices
-        num_stages = max(out_indices) + 1
-
-        self.out_indices = out_indices
-        self.feature_location = feature_location
         self.drop_rate = drop_rate
         self._in_chs = in_chans

@@ -439,14 +433,10 @@ class EfficientNetFeatures(nn.Module):
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
-        self._stage_to_feature_idx = {
-            v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
+        self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
         self._in_chs = builder.in_chs

         efficientnet_init_weights(self)
-        if _DEBUG:
-            for fi, v in enumerate(self.feature_info):
-                print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))

         # Register feature extraction hooks with FeatureHooks helper
         self.feature_hooks = None
@@ -460,14 +450,17 @@ class EfficientNetFeatures(nn.Module):
         x = self.act1(x)
         if self.feature_hooks is None:
             features = []
+            if 0 in self._stage_out_idx:
+                features.append(x)  # add stem out
             for i, b in enumerate(self.blocks):
                 x = b(x)
-                if i in self._stage_to_feature_idx:
+                if i + 1 in self._stage_out_idx:
                     features.append(x)
             return features
         else:
             self.blocks(x)
-            return self.feature_hooks.get_output(x.device)
+            out = self.feature_hooks.get_output(x.device)
+            return list(out.values())


 def _create_effnet(model_kwargs, variant, pretrained=False):
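
As a side note, the new _stage_out_idx replaces the old _stage_to_feature_idx: builder feature entries carry a 'stage' index (stem is stage 0), and the forward pass compares each block stack's position i + 1 against it. A minimal sketch with made-up entries (not timm code):

# Hypothetical feature_info entries as the builder might record them; only 'stage' matters here.
feature_info = [
    dict(stage=0, module='act1', num_chs=16),                 # stem tap
    dict(stage=2, module='blocks.1.1.conv_pwl', num_chs=24),
    dict(stage=3, module='blocks.2.2.conv_pwl', num_chs=40),
]
out_indices = (0, 1, 2)
stage_out_idx = {v['stage']: i for i, v in enumerate(feature_info) if i in out_indices}
print(stage_out_idx)  # {0: 0, 2: 1, 3: 2}
# forward(): the stem output is appended when 0 is a key; block stack i when i + 1 is a key
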
@@ -128,10 +128,9 @@ class ConvBnAct(nn.Module):
         self.act1 = act_layer(inplace=True)

     def feature_info(self, location):
-        if location == 'expansion' or location == 'depthwise':
-            # no expansion or depthwise this block, use act after conv
+        if location == 'expansion':  # output of conv after act, same as block coutput
             info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
-        else:  # location == 'bottleneck'
+        else:  # location == 'bottleneck', block output
             info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
         return info

@@ -175,12 +174,9 @@ class DepthwiseSeparableConv(nn.Module):
         self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()

     def feature_info(self, location):
-        if location == 'expansion':
-            # no expansion in this block, use depthwise, before SE
-            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
-        elif location == 'depthwise':  # after SE
+        if location == 'expansion':  # after SE, input to PW
             info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
-        else:  # location == 'bottleneck'
+        else:  # location == 'bottleneck', block output
             info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
         return info

@@ -245,11 +241,9 @@ class InvertedResidual(nn.Module):
         self.bn3 = norm_layer(out_chs, **norm_kwargs)

     def feature_info(self, location):
-        if location == 'expansion':
-            info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
-        elif location == 'depthwise':  # after SE
+        if location == 'expansion':  # after SE, input to PWL
             info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
-        else:  # location == 'bottleneck'
+        else:  # location == 'bottleneck', block output
             info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
         return info

@@ -370,12 +364,9 @@ class EdgeResidual(nn.Module):
         self.bn2 = norm_layer(out_chs, **norm_kwargs)

     def feature_info(self, location):
-        if location == 'expansion':
-            info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
-        elif location == 'depthwise':
-            # there is no depthwise, take after SE, before PWL
+        if location == 'expansion':  # after SE, before PWL
             info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
-        else:  # location == 'bottleneck'
+        else:  # location == 'bottleneck', block output
             info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
         return info

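
The feature_info dicts above only describe where to tap; a hooks helper turns them into actual hooks elsewhere. A toy illustration (plain PyTorch, not timm code) of what hook_type='forward_pre' on conv_pwl captures versus the block output:

import torch
import torch.nn as nn

# Hypothetical stand-in for the tail of an inverted residual: expansion conv -> act -> pointwise-linear.
conv_pw = nn.Conv2d(16, 64, 1)    # expansion
act1 = nn.ReLU()
conv_pwl = nn.Conv2d(64, 24, 1)   # pointwise-linear projection

captured = {}
def pre_hook(module, inputs):
    # 'forward_pre' hook: inputs[0] is the tensor that feature_location='expansion' taps
    captured['expansion'] = inputs[0]

conv_pwl.register_forward_pre_hook(pre_hook)
out = conv_pwl(act1(conv_pw(torch.randn(1, 16, 32, 32))))
print(captured['expansion'].shape)  # torch.Size([1, 64, 32, 32]) -- expanded features
print(out.shape)                    # torch.Size([1, 24, 32, 32]) -- 'bottleneck' (block output)
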
@@ -1,7 +1,6 @@
 import logging
 import math
 import re
-from collections import OrderedDict
 from copy import deepcopy

 import torch.nn as nn
@@ -12,6 +11,11 @@ from .layers import CondConv2d, get_condconv_initializer
 __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]


+def _log_info_if(msg, condition):
+    if condition:
+        logging.info(msg)
+
+
 def _parse_ksize(ss):
     if ss.isdigit():
         return int(ss)
@@ -219,8 +223,12 @@ class EfficientNetBuilder:
         self.norm_layer = norm_layer
         self.norm_kwargs = norm_kwargs
         self.drop_path_rate = drop_path_rate
+        if feature_location == 'depthwise':
+            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
+            logging.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
+            feature_location = 'expansion'
         self.feature_location = feature_location
-        assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
+        assert feature_location in ('bottleneck', 'expansion', '')
         self.verbose = verbose

         # state updated during build, consumed by model
@@ -247,8 +255,7 @@ class EfficientNetBuilder:
         if bt == 'ir':
             ba['drop_path_rate'] = drop_path_rate
             ba['se_kwargs'] = self.se_kwargs
-            if self.verbose:
-                logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
+            _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             if ba.get('num_experts', 0) > 0:
                 block = CondConvResidual(**ba)
             else:
@@ -256,18 +263,15 @@ class EfficientNetBuilder:
         elif bt == 'ds' or bt == 'dsa':
             ba['drop_path_rate'] = drop_path_rate
             ba['se_kwargs'] = self.se_kwargs
-            if self.verbose:
-                logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
+            _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = DepthwiseSeparableConv(**ba)
         elif bt == 'er':
             ba['drop_path_rate'] = drop_path_rate
             ba['se_kwargs'] = self.se_kwargs
-            if self.verbose:
-                logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
+            _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = EdgeResidual(**ba)
         elif bt == 'cn':
-            if self.verbose:
-                logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
+            _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
             block = ConvBnAct(**ba)
         else:
             assert False, 'Uknkown block type (%s) while building model.' % bt
@@ -279,64 +283,55 @@ class EfficientNetBuilder:
         """ Build the blocks
         Args:
             in_chs: Number of input-channels passed to first block
-            model_block_args: A list of lists, outer list defines stages, inner
+            model_block_args: A list of lists, outer list defines stacks (block stages), inner
                 list contains strings defining block configuration(s)
         Return:
             List of block stacks (each stack wrapped in nn.Sequential)
         """
-        if self.verbose:
-            logging.info('Building model trunk with %d stages...' % len(model_block_args))
+        _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
         self.in_chs = in_chs
         total_block_count = sum([len(x) for x in model_block_args])
         total_block_idx = 0
         current_stride = 2
         current_dilation = 1
         stages = []
-        # outer list of block_args defines the stacks ('stages' by some conventions)
-        for stage_idx, stage_block_args in enumerate(model_block_args):
-            last_stack = stage_idx == (len(model_block_args) - 1)
-            if self.verbose:
-                logging.info('Stack: {}'.format(stage_idx))
-            assert isinstance(stage_block_args, list)
+        if model_block_args[0][0]['stride'] > 1:
+            # if the first block starts with a stride, we need to extract first level feat from stem
+            feature_info = dict(
+                module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
+                hook_type='forward' if self.feature_location != 'bottleneck' else '')
+            self.features.append(feature_info)
+
+        # outer list of block_args defines the stacks
+        for stack_idx, stack_args in enumerate(model_block_args):
+            last_stack = stack_idx + 1 == len(model_block_args)
+            _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
+            assert isinstance(stack_args, list)

             blocks = []
-            # each stack (stage) contains a list of block arguments
-            for block_idx, block_args in enumerate(stage_block_args):
-                last_block = block_idx == (len(stage_block_args) - 1)
-                extract_features = ''  # No features extracted
-                if self.verbose:
-                    logging.info(' Block: {}'.format(block_idx))
+            # each stack (stage of blocks) contains a list of block arguments
+            for block_idx, block_args in enumerate(stack_args):
+                last_block = block_idx + 1 == len(stack_args)
+                _log_info_if(' Block: {}'.format(block_idx), self.verbose)

-                # Sort out stride, dilation, and feature extraction details
                 assert block_args['stride'] in (1, 2)
-                if block_idx >= 1:
-                    # only the first block in any stack can have a stride > 1
+                if block_idx >= 1:  # only the first block in any stack can have a stride > 1
                     block_args['stride'] = 1

-                do_extract = False
-                if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
-                    if last_block:
-                        next_stage_idx = stage_idx + 1
-                        if next_stage_idx >= len(model_block_args):
-                            do_extract = True
-                        else:
-                            do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
-                elif self.feature_location == 'expansion':
-                    if block_args['stride'] > 1 or (last_stack and last_block):
-                        do_extract = True
-                if do_extract:
-                    extract_features = self.feature_location
+                extract_features = False
+                if last_block:
+                    next_stack_idx = stack_idx + 1
+                    extract_features = next_stack_idx >= len(model_block_args) or \
+                        model_block_args[next_stack_idx][0]['stride'] > 1

                 next_dilation = current_dilation
-                next_output_stride = current_stride
                 if block_args['stride'] > 1:
                     next_output_stride = current_stride * block_args['stride']
                     if next_output_stride > self.output_stride:
                         next_dilation = current_dilation * block_args['stride']
                         block_args['stride'] = 1
-                        if self.verbose:
-                            logging.info(' Converting stride to dilation to maintain output_stride=={}'.format(
-                                self.output_stride))
+                        _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
+                            self.output_stride), self.verbose)
                     else:
                         current_stride = next_output_stride
                 block_args['dilation'] = current_dilation
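
The net effect of the rewrite above is a much simpler tap rule: the last block of a stack is marked for extraction when the next stack downsamples or when there is no next stack. A standalone sketch with made-up stack configs (not timm code):

# Only the 'stride' of each stack's first block matters for the tap decision.
model_block_args = [
    [dict(stride=1)],
    [dict(stride=2), dict(stride=1)],
    [dict(stride=2)],
    [dict(stride=1)],
]
for stack_idx, stack_args in enumerate(model_block_args):
    for block_idx, _ in enumerate(stack_args):
        last_block = block_idx + 1 == len(stack_args)
        extract_features = False
        if last_block:
            next_stack_idx = stack_idx + 1
            extract_features = next_stack_idx >= len(model_block_args) or \
                model_block_args[next_stack_idx][0]['stride'] > 1
        if extract_features:
            print(f'tap after blocks.{stack_idx}.{block_idx}')
# prints: blocks.0.0 and blocks.1.1 (the following stack strides), blocks.3.0 (last stack)
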
@@ -349,15 +344,11 @@ class EfficientNetBuilder:

                 # stash feature module name and channel info for model feature extraction
                 if extract_features:
-                    feature_info = block.feature_info(extract_features)
-                    module_name = f'blocks.{stage_idx}.{block_idx}'
-                    if 'module' in feature_info and feature_info['module']:
-                        feature_info['module'] = '.'.join([module_name, feature_info['module']])
-                    else:
-                        feature_info['module'] = module_name
-                    feature_info['stage_idx'] = stage_idx
-                    feature_info['block_idx'] = block_idx
-                    feature_info['reduction'] = current_stride
+                    feature_info = dict(
+                        stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
+                    module_name = f'blocks.{stack_idx}.{block_idx}'
+                    leaf_name = feature_info.get('module', '')
+                    feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
                     self.features.append(feature_info)

                 total_block_idx += 1  # incr global block idx (across all stacks)
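
For each tapped block, the builder now stashes a flat dict whose 'stage' is stack_idx + 1 (stage 0 is reserved for the stem tap) and whose 'module' is the full dotted path. A minimal sketch with hypothetical values (not timm code):

block_feature = dict(module='conv_pwl', hook_type='forward_pre', num_chs=120)  # stand-in for block.feature_info('expansion')
stack_idx, block_idx, current_stride = 2, 1, 8
feature_info = dict(stage=stack_idx + 1, reduction=current_stride, **block_feature)
module_name = f'blocks.{stack_idx}.{block_idx}'
leaf_name = feature_info.get('module', '')
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
print(feature_info)
# {'stage': 3, 'reduction': 8, 'module': 'blocks.2.1.conv_pwl', 'hook_type': 'forward_pre', 'num_chs': 120}
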
@@ -162,11 +162,6 @@ class MobileNetV3Features(nn.Module):
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(MobileNetV3Features, self).__init__()
         norm_kwargs = norm_kwargs or {}
-
-        # TODO only create stages needed, currently all stages are created regardless of out_indices
-        num_stages = max(out_indices) + 1
-
-        self.out_indices = out_indices
         self.drop_rate = drop_rate
         self._in_chs = in_chans

@@ -183,14 +178,10 @@ class MobileNetV3Features(nn.Module):
             norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
-        self._stage_to_feature_idx = {
-            v['stage_idx']: fi for fi, v in enumerate(self.feature_info) if fi in self.out_indices}
+        self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
         self._in_chs = builder.in_chs

         efficientnet_init_weights(self)
-        if _DEBUG:
-            for fi, v in enumerate(self.feature_info):
-                print('Feature idx: {}: Name: {}, Channels: {}'.format(fi, v['module'], v['num_chs']))

         # Register feature extraction hooks with FeatureHooks helper
         self.feature_hooks = None
@@ -204,9 +195,11 @@ class MobileNetV3Features(nn.Module):
         x = self.act1(x)
         if self.feature_hooks is None:
             features = []
+            if 0 in self._stage_out_idx:
+                features.append(x)  # add stem out
             for i, b in enumerate(self.blocks):
                 x = b(x)
-                if i in self._stage_to_feature_idx:
+                if i + 1 in self._stage_out_idx:
                     features.append(x)
             return features
         else: