mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update EfficientNet feature extraction for EfficientDet. Add needed MaxPoolSame as well.
This commit is contained in:
parent
e01ccb88ce
commit
1a8f5900ab
@ -326,7 +326,6 @@ class EfficientNet(nn.Module):
|
||||
# Stem
|
||||
if not fix_stem:
|
||||
stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
|
||||
print(stem_size)
|
||||
self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type)
|
||||
self.bn1 = norm_layer(stem_size, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
@ -393,7 +392,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
and object detection models.
|
||||
"""
|
||||
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
|
||||
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
|
||||
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
|
||||
output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
|
||||
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
|
||||
@ -404,6 +403,7 @@ class EfficientNetFeatures(nn.Module):
|
||||
num_stages = max(out_indices) + 1
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.feature_location = feature_location
|
||||
self.drop_rate = drop_rate
|
||||
self._in_chs = in_chans
|
||||
|
||||
@ -420,18 +420,23 @@ class EfficientNetFeatures(nn.Module):
|
||||
channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
|
||||
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
|
||||
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
|
||||
self.feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._feature_info = builder.features # builder provides info about feature channels for each block
|
||||
self._stage_to_feature_idx = {
|
||||
v['stage_idx']: fi for fi, v in self._feature_info.items() if fi in self.out_indices}
|
||||
self._in_chs = builder.in_chs
|
||||
|
||||
efficientnet_init_weights(self)
|
||||
if _DEBUG:
|
||||
for k, v in self.feature_info.items():
|
||||
for k, v in self._feature_info.items():
|
||||
print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
|
||||
|
||||
# Register feature extraction hooks with FeatureHooks helper
|
||||
hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
|
||||
hooks = [dict(name=self.feature_info[idx]['name'], type=hook_type) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
self.feature_hooks = None
|
||||
if feature_location != 'bottleneck':
|
||||
hooks = [dict(
|
||||
name=self._feature_info[idx]['module'],
|
||||
type=self._feature_info[idx]['hook_type']) for idx in out_indices]
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
def feature_channels(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
@ -439,15 +444,32 @@ class EfficientNetFeatures(nn.Module):
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self.feature_info[idx]['num_chs']
|
||||
return [self.feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
return self._feature_info[idx]['num_chs']
|
||||
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
|
||||
|
||||
def feature_info(self, idx=None):
|
||||
""" Feature Channel Shortcut
|
||||
Returns feature channel count for each output index if idx == None. If idx is an integer, will
|
||||
return feature channel count for that feature block index (independent of out_indices setting).
|
||||
"""
|
||||
if isinstance(idx, int):
|
||||
return self._feature_info[idx]
|
||||
return [self._feature_info[i] for i in self.out_indices]
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
if self.feature_hooks is None:
|
||||
features = []
|
||||
for i, b in enumerate(self.blocks):
|
||||
x = b(x)
|
||||
if i in self._stage_to_feature_idx:
|
||||
features.append(x)
|
||||
return features
|
||||
else:
|
||||
self.blocks(x)
|
||||
return self.feature_hooks.get_output(x.device)
|
||||
|
||||
|
||||
def _create_model(model_kwargs, default_cfg, pretrained=False):
|
||||
|
@ -120,11 +120,13 @@ class ConvBnAct(nn.Module):
|
||||
self.bn1 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
|
||||
def feature_module(self, location):
|
||||
return 'act1'
|
||||
|
||||
def feature_channels(self, location):
|
||||
return self.conv.out_channels
|
||||
def feature_info(self, location):
|
||||
if location == 'expansion' or location == 'depthwise':
|
||||
# no expansion or depthwise this block, use act after conv
|
||||
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
|
||||
else: # location == 'bottleneck'
|
||||
info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
|
||||
return info
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
@ -165,12 +167,15 @@ class DepthwiseSeparableConv(nn.Module):
|
||||
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
||||
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
|
||||
|
||||
def feature_module(self, location):
|
||||
# no expansion in this block, pre pw only feature extraction point
|
||||
return 'conv_pw'
|
||||
|
||||
def feature_channels(self, location):
|
||||
return self.conv_pw.in_channels
|
||||
def feature_info(self, location):
|
||||
if location == 'expansion':
|
||||
# no expansion in this block, use depthwise, before SE
|
||||
info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
|
||||
elif location == 'depthwise': # after SE
|
||||
info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
|
||||
else: # location == 'bottleneck'
|
||||
info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
|
||||
return info
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
@ -232,16 +237,14 @@ class InvertedResidual(nn.Module):
|
||||
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
|
||||
self.bn3 = norm_layer(out_chs, **norm_kwargs)
|
||||
|
||||
def feature_module(self, location):
|
||||
if location == 'post_exp':
|
||||
return 'act1'
|
||||
return 'conv_pwl'
|
||||
|
||||
def feature_channels(self, location):
|
||||
if location == 'post_exp':
|
||||
return self.conv_pw.out_channels
|
||||
# location == 'pre_pw'
|
||||
return self.conv_pwl.in_channels
|
||||
def feature_info(self, location):
|
||||
if location == 'expansion':
|
||||
info = dict(module='act1', hook_type='forward', num_chs=self.conv_pw.in_channels)
|
||||
elif location == 'depthwise': # after SE
|
||||
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
|
||||
else: # location == 'bottleneck'
|
||||
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
|
||||
return info
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
@ -359,16 +362,15 @@ class EdgeResidual(nn.Module):
|
||||
mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
|
||||
self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
||||
|
||||
def feature_module(self, location):
|
||||
if location == 'post_exp':
|
||||
return 'act1'
|
||||
return 'conv_pwl'
|
||||
|
||||
def feature_channels(self, location):
|
||||
if location == 'post_exp':
|
||||
return self.conv_exp.out_channels
|
||||
# location == 'pre_pw'
|
||||
return self.conv_pwl.in_channels
|
||||
def feature_info(self, location):
|
||||
if location == 'expansion':
|
||||
info = dict(module='act1', hook_type='forward', num_chs=self.conv_exp.out_channels)
|
||||
elif location == 'depthwise':
|
||||
# there is no depthwise, take after SE, before PWL
|
||||
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
|
||||
else: # location == 'bottleneck'
|
||||
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
|
||||
return info
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
@ -218,7 +218,7 @@ class EfficientNetBuilder:
|
||||
self.norm_kwargs = norm_kwargs
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.feature_location = feature_location
|
||||
assert feature_location in ('pre_pwl', 'post_exp', '')
|
||||
assert feature_location in ('bottleneck', 'depthwise', 'expansion', '')
|
||||
self.verbose = verbose
|
||||
|
||||
# state updated during build, consumed by model
|
||||
@ -313,20 +313,21 @@ class EfficientNetBuilder:
|
||||
block_args['stride'] = 1
|
||||
|
||||
do_extract = False
|
||||
if self.feature_location == 'pre_pwl':
|
||||
if self.feature_location == 'bottleneck' or self.feature_location == 'depthwise':
|
||||
if last_block:
|
||||
next_stage_idx = stage_idx + 1
|
||||
if next_stage_idx >= len(model_block_args):
|
||||
do_extract = True
|
||||
else:
|
||||
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
|
||||
elif self.feature_location == 'post_exp':
|
||||
if block_args['stride'] > 1 or (last_stack and last_block) :
|
||||
elif self.feature_location == 'expansion':
|
||||
if block_args['stride'] > 1 or (last_stack and last_block):
|
||||
do_extract = True
|
||||
if do_extract:
|
||||
extract_features = self.feature_location
|
||||
|
||||
next_dilation = current_dilation
|
||||
next_output_stride = current_stride
|
||||
if block_args['stride'] > 1:
|
||||
next_output_stride = current_stride * block_args['stride']
|
||||
if next_output_stride > self.output_stride:
|
||||
@ -347,14 +348,13 @@ class EfficientNetBuilder:
|
||||
|
||||
# stash feature module name and channel info for model feature extraction
|
||||
if extract_features:
|
||||
feature_module = block.feature_module(extract_features)
|
||||
if feature_module:
|
||||
feature_module = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_module
|
||||
feature_channels = block.feature_channels(extract_features)
|
||||
self.features[feature_idx] = dict(
|
||||
name=feature_module,
|
||||
num_chs=feature_channels
|
||||
)
|
||||
feature_info = block.feature_info(extract_features)
|
||||
if feature_info['module']:
|
||||
feature_info['module'] = 'blocks.{}.{}.'.format(stage_idx, block_idx) + feature_info['module']
|
||||
feature_info['stage_idx'] = stage_idx
|
||||
feature_info['block_idx'] = block_idx
|
||||
feature_info['reduction'] = current_stride
|
||||
self.features[feature_idx] = feature_info
|
||||
feature_idx += 1
|
||||
|
||||
total_block_idx += 1 # incr global block idx (across all stacks)
|
||||
|
@ -1,9 +1,10 @@
|
||||
from .padding import get_padding
|
||||
from .avg_pool2d_same import AvgPool2dSame
|
||||
from .pool2d_same import AvgPool2dSame
|
||||
from .conv2d_same import Conv2dSame
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||
from .pool2d_same import create_pool2d
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_attn import create_attn
|
||||
from .selective_kernel import SelectiveKernelConv
|
||||
|
@ -1,31 +0,0 @@
|
||||
""" AvgPool2d w/ Same Padding
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import List
|
||||
import math
|
||||
|
||||
from .helpers import tup_pair
|
||||
from .padding import pad_same
|
||||
|
||||
|
||||
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
|
||||
ceil_mode: bool = False, count_include_pad: bool = True):
|
||||
x = pad_same(x, kernel_size, stride)
|
||||
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||
|
||||
|
||||
class AvgPool2dSame(nn.AvgPool2d):
|
||||
""" Tensorflow like 'SAME' wrapper for 2D average pooling
|
||||
"""
|
||||
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||
kernel_size = tup_pair(kernel_size)
|
||||
stride = tup_pair(stride)
|
||||
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||
|
||||
def forward(self, x):
|
||||
return avg_pool2d_same(
|
||||
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
|
@ -14,7 +14,8 @@ from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .helpers import tup_pair
|
||||
from .conv2d_same import get_padding_value, conv2d_same
|
||||
from .conv2d_same import conv2d_same
|
||||
from timm.models.layers.padding import get_padding_value
|
||||
|
||||
|
||||
def get_condconv_initializer(initializer, num_experts, expert_shape):
|
||||
|
@ -5,10 +5,10 @@ Hacked together by Ross Wightman
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Union, List, Tuple, Optional, Callable
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from .padding import get_padding, pad_same, is_static_pad
|
||||
from timm.models.layers.padding import get_padding_value
|
||||
from .padding import pad_same
|
||||
|
||||
|
||||
def conv2d_same(
|
||||
@ -31,29 +31,6 @@ class Conv2dSame(nn.Conv2d):
|
||||
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
|
||||
dynamic = False
|
||||
if isinstance(padding, str):
|
||||
# for any string padding, the padding will be calculated for you, one of three ways
|
||||
padding = padding.lower()
|
||||
if padding == 'same':
|
||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||
if is_static_pad(kernel_size, **kwargs):
|
||||
# static case, no extra overhead
|
||||
padding = get_padding(kernel_size, **kwargs)
|
||||
else:
|
||||
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
||||
padding = 0
|
||||
dynamic = True
|
||||
elif padding == 'valid':
|
||||
# 'VALID' padding, same as padding=0
|
||||
padding = 0
|
||||
else:
|
||||
# Default to PyTorch style 'same'-ish symmetric padding
|
||||
padding = get_padding(kernel_size, **kwargs)
|
||||
return padding, dynamic
|
||||
|
||||
|
||||
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
||||
padding = kwargs.pop('padding', '')
|
||||
kwargs.setdefault('bias', False)
|
||||
|
@ -3,7 +3,7 @@
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -25,9 +25,32 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
|
||||
|
||||
|
||||
# Dynamically pad input x with 'SAME' padding for conv with specified args
|
||||
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
|
||||
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
|
||||
ih, iw = x.size()[-2:]
|
||||
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
|
||||
return x
|
||||
|
||||
|
||||
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
|
||||
dynamic = False
|
||||
if isinstance(padding, str):
|
||||
# for any string padding, the padding will be calculated for you, one of three ways
|
||||
padding = padding.lower()
|
||||
if padding == 'same':
|
||||
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
|
||||
if is_static_pad(kernel_size, **kwargs):
|
||||
# static case, no extra overhead
|
||||
padding = get_padding(kernel_size, **kwargs)
|
||||
else:
|
||||
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
||||
padding = 0
|
||||
dynamic = True
|
||||
elif padding == 'valid':
|
||||
# 'VALID' padding, same as padding=0
|
||||
padding = 0
|
||||
else:
|
||||
# Default to PyTorch style 'same'-ish symmetric padding
|
||||
padding = get_padding(kernel_size, **kwargs)
|
||||
return padding, dynamic
|
||||
|
71
timm/models/layers/pool2d_same.py
Normal file
71
timm/models/layers/pool2d_same.py
Normal file
@ -0,0 +1,71 @@
|
||||
""" AvgPool2d w/ Same Padding
|
||||
|
||||
Hacked together by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Union, List, Tuple, Optional
|
||||
import math
|
||||
|
||||
from .helpers import tup_pair
|
||||
from .padding import pad_same, get_padding_value
|
||||
|
||||
|
||||
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
|
||||
ceil_mode: bool = False, count_include_pad: bool = True):
|
||||
# FIXME how to deal with count_include_pad vs not for external padding?
|
||||
x = pad_same(x, kernel_size, stride)
|
||||
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||
|
||||
|
||||
class AvgPool2dSame(nn.AvgPool2d):
|
||||
""" Tensorflow like 'SAME' wrapper for 2D average pooling
|
||||
"""
|
||||
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||
kernel_size = tup_pair(kernel_size)
|
||||
stride = tup_pair(stride)
|
||||
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
|
||||
|
||||
def forward(self, x):
|
||||
return avg_pool2d_same(
|
||||
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
|
||||
|
||||
|
||||
def max_pool2d_same(
|
||||
x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
|
||||
dilation: List[int] = (1, 1), ceil_mode: bool = False):
|
||||
x = pad_same(x, kernel_size, stride, value=-float('inf'))
|
||||
return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
|
||||
|
||||
|
||||
class MaxPool2dSame(nn.MaxPool2d):
|
||||
""" Tensorflow like 'SAME' wrapper for 2D max pooling
|
||||
"""
|
||||
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
|
||||
kernel_size = tup_pair(kernel_size)
|
||||
stride = tup_pair(stride)
|
||||
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
|
||||
|
||||
def forward(self, x):
|
||||
return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
|
||||
|
||||
|
||||
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
|
||||
stride = stride or kernel_size
|
||||
padding = kwargs.pop('padding', '')
|
||||
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
|
||||
if is_dynamic:
|
||||
if pool_type == 'avg':
|
||||
return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
|
||||
elif pool_type == 'max':
|
||||
return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
|
||||
else:
|
||||
assert False, f'Unsupported pool type {pool_type}'
|
||||
else:
|
||||
if pool_type == 'avg':
|
||||
return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
|
||||
elif pool_type == 'max':
|
||||
return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
|
||||
else:
|
||||
assert False, f'Unsupported pool type {pool_type}'
|
Loading…
x
Reference in New Issue
Block a user